In [1]:
from turtle import distance
import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import pandas as pd 
import random 
import math
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import pickle
import cv2

In [2]:
# Tracking training
import wandb

# wandb.init(project="nih_animal_behavior", entity="serrelab")

In [3]:
data_dir = 'dataset'

train_dataset_list = pickle.load(open('/home/anagara8/Documents/Autoencoder/postext_frames.pickle', 'rb'))
test_dataset_list = train_dataset_list

In [10]:
number_of_frames_per_video = 2
real_train_data = []
dimension = (224, 224)

for data in train_dataset_list:
    temp = []
    if len(data) == number_of_frames_per_video:
        for frame_index in range(number_of_frames_per_video):#cv2.resize(img, dim, interpolation = cv2.INTER_AREA)
            temp.append(cv2.resize(cv2.cvtColor(data[frame_index], cv2.COLOR_BGR2GRAY), dimension, interpolation = cv2.INTER_AREA))
    
    else:
        continue
    real_train_data += temp

In [11]:
print(len(train_dataset_list), len(train_dataset_list[0]),len(train_dataset_list[0][0]), len(train_dataset_list[0][0][0]))

246 2 480 640


In [12]:
print(len(real_train_data), len(real_train_data[0]), len(real_train_data[0][0]))

492 224 224


In [13]:
train_dataset_list = real_train_data
test_dataset_list = real_train_data

In [14]:
training_data_tensor = torch.Tensor(train_dataset_list)
testing_data_tensor = torch.Tensor(test_dataset_list)
print("Converted to Tensors")

Converted to Tensors


In [15]:
training_dataset = TensorDataset(training_data_tensor)
testing_dataset = TensorDataset(testing_data_tensor)
print("Converted to Tensor Dataset")

Converted to Tensor Dataset


In [16]:
train_dataset = DataLoader(training_dataset) # .dataset
test_dataset = DataLoader(testing_dataset)
print("Converted to DataLoader")

Converted to DataLoader


In [17]:
print(len(training_dataset), "x", len(training_dataset[0]))
for i in training_dataset:
    print(i)
    break

492 x 1
(tensor([[216., 157., 168.,  ..., 188., 189., 188.],
        [216., 218., 216.,  ..., 189., 190., 188.],
        [220., 219., 218.,  ..., 189., 189., 189.],
        ...,
        [205., 207., 207.,  ..., 199., 199., 199.],
        [207., 208., 208.,  ..., 201., 201., 201.],
        [208., 209., 208.,  ..., 201., 200., 201.]]),)


### MNIST Data

In [18]:
train_dataset_mnist = torchvision.datasets.MNIST(data_dir, train=True, download=True)
test_dataset_mnist  = torchvision.datasets.MNIST(data_dir, train=False, download=True)

train_transform = transforms.Compose([
transforms.ToTensor(),
# transforms.Resize((224, 224))
])

test_transform = transforms.Compose([
transforms.ToTensor(),
# transforms.Resize((224, 224))
])

train_dataset_mnist.transform = train_transform
test_dataset_mnist.transform = test_transform


### Train-Validation Split

In [19]:
m=len(train_dataset)

print("Length:",m, int(math.ceil(m-m*0.2)), int(m*0.2))
train_data, val_data = random_split(train_dataset, [int(math.ceil(m-m*0.2)), int(m*0.2)])
batch_size=256

Length: 492 394 98


In [20]:
m2=len(train_dataset_mnist)

print("Length:", m2, int(m2-m2*0.2), int(math.ceil(m2*0.2)))
train_data_mnist, val_data_mnist = random_split(train_dataset_mnist, [int(m2-m2*0.2), int(math.ceil(m2*0.2))])
batch_size=256

Length: 60000 48000 12000


### All datasets as Dataloaders

In [21]:
train_loader_mnist = torch.utils.data.DataLoader(train_data_mnist, batch_size=batch_size)
valid_loader_mnist = torch.utils.data.DataLoader(val_data_mnist, batch_size=batch_size)
test_loader_mnist = torch.utils.data.DataLoader(test_dataset_mnist, batch_size=batch_size,shuffle=True)

In [22]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,shuffle=True)

### ResNet-18 Model

In [23]:
class ResizeConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='nearest'):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        x = self.conv(x)
        return x

class BasicBlockEnc(nn.Module):

    def __init__(self, in_planes, stride=1):
        super().__init__()

        planes = in_planes*stride

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

class BasicBlockDec(nn.Module):

    def __init__(self, in_planes, stride=1):
        super().__init__()

        planes = int(in_planes/stride)

        self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(in_planes)
        # self.bn1 could have been placed here, but that messes up the order of the layers when printing the class

        if stride == 1:
            self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(planes)
            self.shortcut = nn.Sequential()
        else:
            self.conv1 = ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride)
            self.bn1 = nn.BatchNorm2d(planes)
            self.shortcut = nn.Sequential(
                ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        out = torch.relu(self.bn2(self.conv2(x)))
        out = self.bn1(self.conv1(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

class ResNet18Enc(nn.Module):

    def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
        super().__init__()
        self.in_planes = 64
        self.z_dim = z_dim
        self.conv1 = nn.Conv2d(nc, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(BasicBlockEnc, 64, num_Blocks[0], stride=1)
        self.layer2 = self._make_layer(BasicBlockEnc, 128, num_Blocks[1], stride=2)
        self.layer3 = self._make_layer(BasicBlockEnc, 256, num_Blocks[2], stride=2)
        self.layer4 = self._make_layer(BasicBlockEnc, 512, num_Blocks[3], stride=2)
        self.linear = nn.Linear(512, 2 * z_dim)

    def _make_layer(self, BasicBlockEnc, planes, num_Blocks, stride):
        strides = [stride] + [1]*(num_Blocks-1)
        layers = []
        for stride in strides:
            layers += [BasicBlockEnc(self.in_planes, stride)]
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        mu = x[:, :self.z_dim]
        logvar = x[:, self.z_dim:]
        return mu, logvar

class ResNet18Dec(nn.Module):

    def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
        super().__init__()
        self.in_planes = 512

        self.linear = nn.Linear(z_dim, 512)

        self.layer4 = self._make_layer(BasicBlockDec, 256, num_Blocks[3], stride=2)
        self.layer3 = self._make_layer(BasicBlockDec, 128, num_Blocks[2], stride=2)
        self.layer2 = self._make_layer(BasicBlockDec, 64, num_Blocks[1], stride=2)
        self.layer1 = self._make_layer(BasicBlockDec, 64, num_Blocks[0], stride=1)
        self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2)

    def _make_layer(self, BasicBlockDec, planes, num_Blocks, stride):
        strides = [stride] + [1]*(num_Blocks-1)
        layers = []
        for stride in reversed(strides):
            layers += [BasicBlockDec(self.in_planes, stride)]
        self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, z):
        x = self.linear(z)
        x = x.view(z.size(0), 512, 1, 1)
        x = F.interpolate(x, scale_factor=4)
        x = self.layer4(x)
        x = self.layer3(x)
        x = self.layer2(x)
        x = self.layer1(x)
        x = torch.sigmoid(self.conv1(x))
        x = x.view(x.size(0), 3, 64, 64)
        return x

class VAE(nn.Module):

    def __init__(self, z_dim):
        super().__init__()
        self.encoder = ResNet18Enc(z_dim=z_dim)
        self.decoder = ResNet18Dec(z_dim=z_dim)

    def forward(self, x):
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        x = self.decoder(z)
        return x, z
    
    @staticmethod
    def reparameterize(mean, logvar):
        std = torch.exp(logvar / 2) # in log-space, squareroot is divide by two
        epsilon = torch.randn_like(std)
        return epsilon * std + mean

### Shallow Autoencoder

In [38]:
class Encoder(nn.Module):
    # ResNet-18
    # U-Net Backbone
    # Save the failed models
    # Auto-regression on the latent space
    def __init__(self, encoded_space_dim, fc2_input_dim):
        super().__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True)
        )
        
        self.flatten = nn.Flatten(start_dim=1)
        self.encoder_lin = nn.Sequential(
            nn.Linear(3 * 3 * 32, 128),
            nn.ReLU(True),
            nn.Linear(128, encoded_space_dim)
        )
        
    def forward(self, x):
        x = self.encoder_cnn(x)
        x = self.flatten(x)
        x = self.encoder_lin(x)
        return x
class Decoder(nn.Module):
    
    def __init__(self, encoded_space_dim,fc2_input_dim):
        super().__init__()
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 3 * 3 * 32),
            nn.ReLU(True)
        )

        self.unflatten = nn.Unflatten(dim=1, 
        unflattened_size=(32, 3, 3))

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, 
            stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, 
            padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, 
            padding=1, output_padding=1)
        )
        
    def forward(self, x):
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        x = torch.sigmoid(x)
        return x

In [39]:
loss_fn = torch.nn.MSELoss()

In [40]:
latent_vector_dimension = 16

In [41]:
encoder = Encoder(encoded_space_dim=latent_vector_dimension, fc2_input_dim=224)
decoder = Decoder(encoded_space_dim=latent_vector_dimension, fc2_input_dim=224)
params_to_optimize = [
    {'params': encoder.parameters()},
    {'params': decoder.parameters()}
]

In [42]:
lr= 0.001
optim = torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=1e-05)

# Manual seed for reproducible results
torch.manual_seed(0)

<torch._C.Generator at 0x7f1f729c7a90>

In [43]:
# Check if the GPU is available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Selected device: {device}')

# Move both the encoder and the decoder to the selected device
encoder.to(device)
decoder.to(device)

Selected device: cuda


Decoder(
  (decoder_lin): Sequential(
    (0): Linear(in_features=16, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): ReLU(inplace=True)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(32, 16, 16))
  (decoder_conv): Sequential(
    (0): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(8, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
  )
)

In [44]:
### Training function
def train_epoch(encoder, decoder, device, dataloader, loss_fn, optimizer):
    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()
    train_loss = []
    # Iterate the dataloader (we do not need the label values, this is unsupervised learning)
#     for image_batch, _ in dataloader: # with "_" we just ignore the labels (the second element of the dataloader tuple)
    for image_batch in dataloader:
        # Move tensor to the proper device
        print(type(image_batch), image_batch)
        image_batch = image_batch[0].to(device)
        # Encode data
        encoded_data = encoder(image_batch)
        # Decode data
        decoded_data = decoder(encoded_data)
        # Evaluate loss
        loss = loss_fn(decoded_data, image_batch)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Print batch loss
        print('\t partial train loss (single batch): %f' % (loss.data))
        train_loss.append(loss.detach().cpu().numpy())

    return np.mean(train_loss)

### Testing function
def test_epoch(encoder, decoder, device, dataloader, loss_fn):
    # Set evaluation mode for encoder and decoder
    encoder.eval()
    decoder.eval()
    with torch.no_grad(): # No need to track the gradients
        # Define the lists to store the outputs for each batch
        conc_out = []
        conc_label = []
        for image_batch in dataloader:
            # Move tensor to the proper device
            image_batch = image_batch.to(device)
            # Encode data
            encoded_data = encoder(image_batch)
            # Decode data
            decoded_data = decoder(encoded_data)
            # Append the network output and the original image to the lists
            conc_out.append(decoded_data.cpu())
            conc_label.append(image_batch.cpu())
        # Create a single tensor with all the values in the lists
        conc_out = torch.cat(conc_out)
        conc_label = torch.cat(conc_label) 
        # Evaluate global loss
        val_loss = loss_fn(conc_out, conc_label)
    return val_loss.data

def plot_ae_outputs(encoder,decoder,n=10):
    plt.figure(figsize=(16,4.5))
    targets = test_dataset.targets.numpy()
    t_idx = {i:np.where(targets==i)[0][0] for i in range(n)}
    for i in range(n):
      ax = plt.subplot(2,n,i+1)
      img = test_dataset[t_idx[i]][0].unsqueeze(0).to(device)
      encoder.eval()
      decoder.eval()
      with torch.no_grad():
         rec_img  = decoder(encoder(img))
      plt.imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)  
      if i == n//2:
        ax.set_title('Original images')
      ax = plt.subplot(2, n, i + 1 + n)
      plt.imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')  
      ax.get_xaxis().set_visible(False)
      ax.get_yaxis().set_visible(False)  
      if i == n//2:
         ax.set_title('Reconstructed images')
    plt.show()

In [45]:
num_epochs = 500

wandb.config = {
  "learning_rate": lr,
  "epochs": num_epochs,
  "batch_size": batch_size
}

In [46]:
train_loader_mnist.__dict__

{'dataset': <torch.utils.data.dataset.Subset at 0x7f207fc89bb0>,
 'num_workers': 0,
 'prefetch_factor': 2,
 'pin_memory': False,
 'timeout': 0,
 'worker_init_fn': None,
 '_DataLoader__multiprocessing_context': None,
 '_dataset_kind': 0,
 'batch_size': 256,
 'drop_last': False,
 'sampler': <torch.utils.data.sampler.SequentialSampler at 0x7f207fc222e0>,
 'batch_sampler': <torch.utils.data.sampler.BatchSampler at 0x7f207fc22160>,
 'generator': None,
 'collate_fn': <function torch.utils.data._utils.collate.default_collate(batch)>,
 'persistent_workers': False,
 '_DataLoader__initialized': True,
 '_IterableDataset_len_called': None,
 '_iterator': None}

In [47]:
train_loader.__dict__

{'dataset': <torch.utils.data.dataset.Subset at 0x7f207fc366d0>,
 'num_workers': 0,
 'prefetch_factor': 2,
 'pin_memory': False,
 'timeout': 0,
 'worker_init_fn': None,
 '_DataLoader__multiprocessing_context': None,
 '_dataset_kind': 0,
 'batch_size': 256,
 'drop_last': False,
 'sampler': <torch.utils.data.sampler.SequentialSampler at 0x7f208fc05880>,
 'batch_sampler': <torch.utils.data.sampler.BatchSampler at 0x7f208fc055e0>,
 'generator': None,
 'collate_fn': <function torch.utils.data._utils.collate.default_collate(batch)>,
 'persistent_workers': False,
 '_DataLoader__initialized': True,
 '_IterableDataset_len_called': None,
 '_iterator': None}

In [48]:
examples = enumerate(test_loader.dataset)
batch_idx, example_data = next(examples)
print(batch_idx, example_data)

0 [tensor([[[216., 157., 168.,  ..., 188., 189., 188.],
         [216., 218., 216.,  ..., 189., 190., 188.],
         [220., 219., 218.,  ..., 189., 189., 189.],
         ...,
         [205., 207., 207.,  ..., 199., 199., 199.],
         [207., 208., 208.,  ..., 201., 201., 201.],
         [208., 209., 208.,  ..., 201., 200., 201.]]])]


In [36]:
diz_loss = {'train_loss':[],'val_loss':[]}
for epoch in range(num_epochs):
   train_loss =train_epoch(encoder, decoder, device, test_loader.dataset, loss_fn, optim)
   val_loss = test_epoch(encoder, decoder, device, test_loader.dataset, loss_fn)
   print('\n EPOCH {}/{} \t train loss {} \t val loss {}'.format(epoch + 1, num_epochs, train_loss,val_loss))
   diz_loss['train_loss'].append(train_loss)
   diz_loss['val_loss'].append(val_loss)
   wandb.log({"loss": val_loss})
   wandb.watch(encoder)  

<class 'list'> [tensor([[[216., 157., 168.,  ..., 188., 189., 188.],
         [216., 218., 216.,  ..., 189., 190., 188.],
         [220., 219., 218.,  ..., 189., 189., 189.],
         ...,
         [205., 207., 207.,  ..., 199., 199., 199.],
         [207., 208., 208.,  ..., 201., 201., 201.],
         [208., 209., 208.,  ..., 201., 200., 201.]]])]


RuntimeError: Expected 4-dimensional input for 4-dimensional weight [8, 1, 3, 3], but got 3-dimensional input of size [1, 224, 224] instead

In [None]:
test_epoch(encoder,decoder,device,test_loader,loss_fn).item()

In [None]:
# Plot losses
plt.figure(figsize=(10,8))
plt.semilogy(diz_loss['train_loss'], label='Train')
plt.semilogy(diz_loss['val_loss'], label='Valid')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
#plt.grid()
plt.legend()
#plt.title('loss')
plt.show()

In [None]:
### Random Reconstructed Images
def show_image(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


In [None]:
encoder.eval()
decoder.eval()

In [None]:
with torch.no_grad():
    # calculate mean and std of latent code, generated takining in test images as inputs 
    images, labels = iter(test_loader).next()
    images = images.to(device)
    latent = encoder(images)
    latent = latent.cpu()

    mean = latent.mean(dim=0)
    print(mean)
    std = (latent - mean).pow(2).mean(dim=0).sqrt()
    print(std)

    # sample latent vectors from the normal distribution
    latent = torch.randn(128, d)*std + mean

    # reconstruct images from the random latent vectors
    latent = latent.to(device)
    img_recon = decoder(latent)
    img_recon = img_recon.cpu()

    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon[:100],10,5))
    plt.show()