In [0]:
from __future__ import print_function, division
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
import torchvision

# Target device for the model and batches; fall back to CPU so the notebook
# still runs (slowly) on a runtime without a GPU instead of crashing later.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [0]:
# Mount Google Drive so files under '/content/drive/My Drive' are reachable.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [0]:
# Drive is mounted at the absolute path '/content/drive', so this path must be
# absolute too; a relative path would resolve against the working directory.
data_path = '/content/drive/My Drive'
load_path = 'sample_data'

In [0]:
!cp '/content/drive/My Drive/data.zip' /content
!unzip data.zip

In [0]:
def dict_to_device(dictionary, device):
    """Move every value of ``dictionary`` to ``device``, updating it in place.

    Each value must provide a ``.to(device)`` method (e.g. a torch tensor).
    The same dict object is mutated and also returned, so callers may use
    either the side effect or the return value.
    """
    moved = {key: value.to(device) for key, value in dictionary.items()}
    dictionary.update(moved)
    return dictionary

In [0]:
class Killer_Whale_Dataset(Dataset):
    """Paired image/mask dataset for the killer-whale data.

    ``data_folder`` must contain two sub-directories, ``img/`` and ``mask/``.
    The files in each are sorted by name and paired by position: the i-th
    image goes with the i-th mask. Both folders should therefore hold the
    same number of files, with names that sort in the same order.

    Args:
        data_folder: root directory containing ``img/`` and ``mask/``.
        transform: optional callable applied to both image and mask. When
            given, it must return a tensor, which is then resized to
            ``out_size`` with adaptive average pooling.
        out_size: (height, width) of the returned tensors when ``transform``
            is set. Defaults to (224, 224).
    """

    def __init__(self, data_folder, transform=None, out_size=(224, 224)):
        super().__init__()
        # Kept for backward compatibility; not read by this class.
        self.folder_list = os.listdir(data_folder)
        self.img_path = os.path.join(data_folder, 'img/')
        self.mask_path = os.path.join(data_folder, 'mask/')
        # os.listdir order is arbitrary; sorting makes the positional
        # image/mask pairing deterministic.
        self.img_list = sorted(os.listdir(self.img_path))
        self.mask_list = sorted(os.listdir(self.mask_path))
        self.transform = transform
        self.out_size = out_size

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, decode it as 3-channel RGB and release the file handle."""
        # Image.open is lazy and keeps the file open. convert() forces a full
        # decode into a new image, so the context manager can close the file.
        # RGB matches the 3-channel Normalize and the 3-channel model output.
        with Image.open(path) as im:
            return im.convert('RGB')

    def _prepare(self, pil_img):
        """Apply ``transform`` (if any) and resize the resulting tensor to ``out_size``."""
        if not self.transform:
            return pil_img
        tensor = self.transform(pil_img)
        # adaptive_avg_pool2d accepts an unbatched (C, H, W) tensor.
        return nn.functional.adaptive_avg_pool2d(tensor, self.out_size)

    def __getitem__(self, idx):
        """Return ``{'img': ..., 'mask': ...}`` for the idx-th pair.

        Uses locals rather than ``self`` attributes, so no per-sample state
        is left on the dataset between calls.
        """
        img = self._load_rgb(os.path.join(self.img_path, self.img_list[idx]))
        mask = self._load_rgb(os.path.join(self.mask_path, self.mask_list[idx]))
        return {'img': self._prepare(img), 'mask': self._prepare(mask)}

    def __len__(self):
        """Number of samples, i.e. the number of files in ``img/``."""
        return len(self.img_list)

In [0]:
def Tensor2Image(t):
    """Convert an image tensor to a ``PIL.Image`` for display.

    Args:
        t: tensor shaped (C, H, W), (H, W), or a batch (N, C, H, W). For a
           batch, only the first sample is converted.

    Returns:
        The image produced by ``transforms.ToPILImage``.
    """
    if t.dim() == 4:
        # ToPILImage rejects 4-D input, and squeeze() only drops a batch
        # dimension of size 1, so pick the first sample explicitly.
        t = t[0]
    # Predictions may still be attached to the autograd graph.
    t = t.detach()
    if t.is_floating_point():
        # ToPILImage multiplies floats by 255 and casts to uint8. Values
        # outside [0, 1] (e.g. normalised data) would overflow that cast.
        t = t.clamp(0, 1)
    return transforms.ToPILImage()(t.squeeze())



class AE(nn.Module):
    """Minimal convolutional encoder/decoder mapping an image to a mask.

    ``forward`` takes a dict with an ``'img'`` tensor and returns
    ``{'img': prediction}``. The 3x3 convolution (stride 1, no padding)
    shrinks each spatial side by 2, and the matching transposed convolution
    grows it back. The prediction therefore has 3 channels and the input's
    spatial size.
    """

    def __init__(self):
        super().__init__()
        # TODO: placeholder architecture; replace with a deeper encoder/decoder.
        self.conv2a = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1)
        self.convtrans2a = nn.ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=3, stride=1)

    def encode(self, dictionary):
        """Encode ``dictionary['img']`` (assumed N x 3 x H x W) into 6 ReLU feature maps."""
        return torch.relu(self.conv2a(dictionary['img']))

    def decode(self, z):
        """Decode 6-channel features to a 3-channel map, returned as ``{'img': output}``.

        The output is deliberately left without an activation. The masks this
        model is trained against are normalised with mean 0.5 / std 0.2 and can
        be negative, and a final ReLU would make those values unreachable.
        """
        return {'img': self.convtrans2a(z)}

    def forward(self, dictionary):
        """Encode then decode; returns ``{'img': prediction}``."""
        z = self.encode(dictionary)
        return self.decode(z)

In [0]:
# Shared by images and masks: ToTensor scales pixels to [0, 1], then
# Normalize maps each channel to (x - 0.5) / 0.2, i.e. roughly [-2.5, 2.5].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.2, 0.2, 0.2]),
])

# NOTE(review): assumed to be the folder extracted from data.zip; confirm.
whale_path = './data'

BATCH_SIZE = 2
LEARNING_RATE = 1e-3

# Seed so weight initialisation and shuffling order are reproducible.
torch.manual_seed(0)

whale_data = Killer_Whale_Dataset(whale_path, transform=transform)
train_loader = DataLoader(whale_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

# Use the configured `device` rather than hard-coding .cuda(), so the model
# sits on the same device the batches are moved to in the training loop.
net_test = AE().to(device)

# TODO: MSE is a placeholder; choose a loss suited to mask prediction.
loss_fn = torch.nn.MSELoss()
optimizer = optim.Adam(net_test.parameters(), lr=LEARNING_RATE)

In [0]:
from IPython import display

N_EPOCHS = 50
losses = []  # one entry per optimisation step

fig = plt.figure(figsize=(20, 5), dpi=80, facecolor='w', edgecolor='k')
axes = fig.subplots(1, 4)

for epoch in range(N_EPOCHS):
    for batch in train_loader:
        dict_to_device(batch, device)
        preds = net_test(batch)
        loss = loss_fn(preds['img'], batch['mask'])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # Clear every panel once, then redraw once per epoch. Plotting inside the
    # clearing loop would repeat the work per axis and stack images/lines on
    # panels that had not been cleared yet.
    for ax in axes:
        ax.cla()
    # Show the first sample of the last batch. A whole batch is 4-D and cannot
    # be turned into a single image. Detach because the prediction is still
    # attached to the autograd graph.
    axes[0].imshow(Tensor2Image(preds['img'][0].detach().cpu()))
    axes[0].set_title('good to see?')
    axes[1].imshow(Tensor2Image(batch['mask'][0].cpu()))
    axes[1].set_title('ground truth')
    axes[3].plot(losses)
    axes[3].set_yscale('log')
    axes[3].set_xlabel("iteration")
    axes[3].set_title('Training loss')
    display.clear_output(wait=True)
    display.display(fig)

# The final epoch's figure is already displayed above; close it so the inline
# backend does not render a duplicate at the end of the cell.
plt.close(fig)