<a href="https://colab.research.google.com/github/qn19325/individualProject/blob/main/cnn_rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm

In [3]:
!pip install wandb --upgrade

Collecting wandb
  Downloading wandb-0.12.14-py2.py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 12.2 MB/s 
Collecting setproctitle
  Downloading setproctitle-1.2.2-cp37-cp37m-manylinux1_x86_64.whl (36 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.5.9-py2.py3-none-any.whl (144 kB)
[K     |████████████████████████████████| 144 kB 50.6 MB/s 
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)
[K     |████████████████████████████████| 181 kB 51.0 MB/s 
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.8-py3-none-any.whl (9.5 kB)
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-4.0.9-py3-none-any.whl (63 kB)
[K     |████████████████████████████████| 63 kB 1.7 MB/s 
[?25hCollecting smmap<6,>=3.0.1
  Downloading smmap-5.0.0-py3-none-any.whl (24 kB)


In [4]:
import wandb 

wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [10]:
config = dict(
    epochs=1000,
    batch_size=8,
    learning_rate=0.001,
    input_size = 128,
    hidden_size = 128,
    output_size = 1,
    sequence_length = 2,
    num_layers = 1,
    dataset="basic64x64",
    architecture="CNN/RNN")

In [32]:
def model_pipeline(hyperparameters):

    # tell wandb to get started
    with wandb.init(project="individual_project", config=hyperparameters):
      # access all HPs through wandb.config, so logging matches execution!
      config = wandb.config

      # make the model, data, and optimization problem
      encoder, model, train_loader, criterion, optimizer = make(config)
      print(model)

      # and use them to train the model
      train(encoder, model, train_loader, criterion, optimizer, config)

    return model

In [31]:
def make(config):
    # Make the data
    train = get_data(train=True)
    train_loader = make_loader(train, batch_size=config.batch_size)

    # Make the CNN encoder
    encoder = Encoder().to(device)
    # Make the RNN model
    model = RNN(config).to(device)

    # Make the loss and optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), 
                                 lr=config.learning_rate)
    
    return encoder, model, train_loader, criterion, optimizer

In [13]:
def get_data(slice=5, train=True):
    full_dataset = ImageDataLoader()
    
    return full_dataset

def make_loader(dataset, batch_size):
    loader = DataLoader(dataset=dataset, 
                        batch_size=batch_size, 
                        shuffle=True)
    return loader

In [14]:
class ImageDataLoader(Dataset):
    def __init__(self, dir_=None):
        self.data_df = pd.read_csv('gdrive/MyDrive/64x64.csv')
        self.dataset_len = len(self.data_df) # read the number of len of your csv files
    def __getitem__(self, idx):
        # load the next image
        f_name_t = self.data_df['Filename'][idx]
        f_name_tp1 = self.data_df['Filename'][idx+1]
        label = self.data_df['Label'][idx]
        label = label.astype(np.float32) 
        label = np.true_divide(label, 20)
        img_t = torchvision.io.read_image('gdrive/MyDrive/64x64/{}'.format(f_name_t))
        img_tp1 = torchvision.io.read_image('gdrive/MyDrive/64x64/{}'.format(f_name_tp1))
        img_t = img_t.float().div_(255.0)
        img_tp1 = img_tp1.float().div_(255.0)
        return img_t, img_tp1, label
    def __len__(self):
        return self.dataset_len - 1

In [15]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 8, 1, 1),
            nn.ReLU(),
            nn.Conv2d(8, 16, 1, 1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc1 = nn.Linear(65536, 128)
    def forward(self, x):
        state = self.cnn(x)
        state = self.fc1(state)
        return state

In [21]:
class RNN(nn.Module):
    def __init__(self, config):
        super(RNN, self).__init__()
        self.batch_size = config.batch_size
        self.input_size = config.input_size
        self.num_layers = config.num_layers
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size
        self.rnn = nn.RNN(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.fc = nn.Linear(self.hidden_size, self.output_size)
    def init_hidden(self):
        return (torch.zeros(self.num_layers, self.batch_size, self.hidden_size).to(device))
    def forward(self, x):
        self.batch_size = x.size(0)
        self.hidden = self.init_hidden()
        out, self.hidden = self.rnn(x, self.hidden)
        out = self.fc(out)
        return out

In [34]:
def train(encoder, model, loader, criterion, optimizer, config):
    # Tell wandb to watch what the model gets up to: gradients, weights, and more!
    wandb.watch(model, criterion, log="all", log_freq=10)

    # Run training and track with wandb
    total_batches = len(loader) * config.epochs
    example_ct = 0  # number of examples seen
    batch_ct = 0
    for epoch in tqdm(range(config.epochs)):
        for _, (images1, images2, labels) in enumerate(loader):

            loss = train_batch(images1, images2, labels, encoder, model, optimizer, criterion)
            example_ct += len(images1)
            batch_ct += 1

            # Report metrics every 25th batch
            if ((batch_ct + 1) % 25) == 0:
                train_log(loss, example_ct, epoch)


def train_batch(images1, images2, labels, encoder, model, optimizer, criterion):
    images1, images2, labels = images1.to(device), images2.to(device), (labels.float()).to(device)
    
    # Forward pass ➡
    # pass to encoder
    output1 = encoder(images1)
    output2 = encoder(images2)
    # pass to RNN
    batch_size1 = len(output1)
    batch_size2 = len(output2)

    output1 = output1.reshape(batch_size1,1,-1)
    output2 = output2.reshape(batch_size2,1,-1)
    
    seq = torch.cat((output1, output2.detach()), dim=1)

    outputs = model(seq.to(device))
    loss = criterion(outputs[:,-1].squeeze(), labels.float())

    # Backward pass ⬅
    optimizer.zero_grad()
    loss.backward()

    # Step with optimizer
    optimizer.step()

    return loss

In [18]:
def train_log(loss, example_ct, epoch):
    # Where the magic happens
    wandb.log({"epoch": epoch, "loss": loss}, step=example_ct)
    print(f"Loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}")

In [35]:
model = model_pipeline(config)

RNN(
  (rnn): RNN(128, 128, batch_first=True)
  (fc): Linear(in_features=128, out_features=1, bias=True)
)


  0%|          | 0/1000 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Loss after 08344 examples: 0.137
Loss after 08544 examples: 0.064
Loss after 08744 examples: 0.076
Loss after 08944 examples: 0.042
Loss after 09138 examples: 0.088
Loss after 09338 examples: 0.128
Loss after 09538 examples: 0.080
Loss after 09738 examples: 0.057
Loss after 09938 examples: 0.091
Loss after 10132 examples: 0.082
Loss after 10332 examples: 0.129
Loss after 10532 examples: 0.074
Loss after 10732 examples: 0.159
Loss after 10932 examples: 0.093
Loss after 11126 examples: 0.099
Loss after 11326 examples: 0.093
Loss after 11526 examples: 0.090
Loss after 11726 examples: 0.133
Loss after 11926 examples: 0.136
Loss after 12120 examples: 0.106
Loss after 12320 examples: 0.058
Loss after 12520 examples: 0.059
Loss after 12720 examples: 0.053
Loss after 12920 examples: 0.080
Loss after 13114 examples: 0.058
Loss after 13314 examples: 0.047
Loss after 13514 examples: 0.094
Loss after 13714 examples: 0.083
Loss after 

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,▅▆██▅▅▃▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,999.0
loss,0.00072
