In [1]:
import scipy.io
import os
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import torch
import pytorch_lightning as ptl
from weatherGan.models.WeatherGan import WeatherGan
from weatherGan.dataloader.dataset import ImageDataset
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import Callback
import random



In [2]:
# Root folder of the dataset; passed to ImageDataset and to ImageGenCallback.
PATH = "datasets/dataset/"

In [3]:
# Load the train and test splits with the custom ImageDataset class from the WeatherGan library.
# Both splits share the same input/target transforms, defined once here instead of duplicated lambdas.
NUM_WEATHER_CLASSES = 2


def scale_to_unit(image):
    """Divide raw pixel values by 255 (maps 8-bit images to [0, 1], assuming 8-bit input)."""
    return image / 255


def weather_one_hot(label, num_classes=NUM_WEATHER_CLASSES):
    """One-hot encode a weather label.

    Args:
        label: anything accepted by ``int`` (e.g. an int or a numeric string).
        num_classes: length of the one-hot vector.

    Returns:
        A 1-D int64 tensor of length ``num_classes`` with a single 1 at ``int(label)``.
    """
    return torch.nn.functional.one_hot(torch.tensor(int(label)), num_classes)


train_dataset = ImageDataset(data_dir=PATH, transform=scale_to_unit, target_transform=weather_one_hot)
val_dataset = ImageDataset(data_dir=PATH, transform=scale_to_unit, target_transform=weather_one_hot, mode='test')

  0%|          | 0/8000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

In [4]:
# Wrap the datasets in DataLoaders; only the training split is shuffled.
BATCH_SIZE = 4
NUM_WORKERS = 8  # lower this on machines with fewer CPU cores
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

In [5]:
# The GAN. The segmentator and the weather classifier were pretrained separately and are
# restored from the checkpoints below. According to the authors, every weight except those
# of the generator and the expert is kept frozen during GAN training.
WEATHER_DISCR_CKPT = 'checkpoints/weatherClassifier/epoch=9-step=2499.ckpt'
SEGMENTATOR_CKPT = 'checkpoints/segmentator/last_seg_ckpt.ckpt'
model = WeatherGan(
    num_classes_segmentation=6,
    num_classes_weather=2,
    noise_size=32,
    weather_discr_ckpt=WEATHER_DISCR_CKPT,
    segmentator_ckpt=SEGMENTATOR_CKPT,
).double()  # the whole model runs in float64


In [6]:
# Checkpointing: with `monitor` set, ModelCheckpoint keeps only the checkpoint with the lowest
# 'val_loss'. `save_last=True` additionally writes `last.ckpt` every time checkpointing runs
# (each epoch with the default settings), so the most recent weights are always on disk.
# NOTE(review): assumes WeatherGan logs a metric named 'val_loss' — confirm in the model code.
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints/gancheckpoints/',
    monitor='val_loss',
    mode='min',
    save_last=True,
)
# Weights & Biases logger used to track the model's metrics during training.
wandbLogger = WandbLogger(project='weatherGan', name='GAN_training_fixed')

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


In [7]:
# Callback that periodically runs the GAN on a few random test images and saves the inputs
# and outputs to disk, so training progress can be inspected offline.
# NOTE(review): the original authors did not have time to test this callback thoroughly,
# yet it is passed to the Trainer in this notebook — treat its snapshots with care.

class ImageGenCallback(Callback):
    """Every ``every_n_batches`` batches, generate images from random test samples and save them.

    Args:
        num_gen: number of test images sampled (with replacement) per snapshot.
        path: dataset root; images are read from ``<path>/test_images/``.
        save_path: directory the ``.pt`` snapshots are written to (created if missing).
        noise_size: length of the noise vector fed to the generator; should match the
            ``noise_size`` the model was built with.
        every_n_batches: snapshot frequency, counted in completed batches.
    """

    def __init__(self, num_gen, path, save_path, noise_size=32, every_n_batches=100):
        super().__init__()
        self.path = path
        self.save_path = save_path
        self.num_gen = num_gen
        self.noise_size = noise_size
        self.every_n_batches = every_n_batches
        # Number of batches seen so far (attribute name kept for backward compatibility).
        self.buffer = 0

    def on_batch_end(self, trainer, pl_module):
        # NOTE(review): `on_batch_end` is deprecated in newer PyTorch Lightning releases in
        # favour of `on_train_batch_end`; kept because it matches the version used here.
        self.buffer += 1
        if self.buffer % self.every_n_batches != 0:
            return
        # Build the inputs on the module's own device/dtype rather than a hard-coded GPU,
        # so the callback also works on CPU or on a different GPU index.
        img, labels, targets, noises = self.__load_random_images__(
            device=pl_module.device, dtype=pl_module.dtype
        )
        # Inference only: no autograd graph is needed for a snapshot.
        with torch.no_grad():
            x_gn, discr, weather = pl_module(img, targets, noises)
        os.makedirs(self.save_path, exist_ok=True)
        snapshots = {
            'true_image': img,
            'image_gen': x_gn,
            'discr': discr,
            'weather': weather,
        }
        for prefix, tensor in snapshots.items():
            # Saved on CPU so the files can be loaded on machines without a GPU.
            file_name = f'{prefix}_{self.buffer}_n.pt'
            torch.save(tensor.detach().cpu(), os.path.join(self.save_path, file_name))

    def __load_random_images__(self, device='cpu', dtype=torch.float64):
        """Sample ``num_gen`` test images (with replacement) and build one generator batch.

        Args:
            device: device the returned tensors are created on.
            dtype: floating dtype for the images and the noise.

        Returns:
            imgs: images stacked and permuted from (N, H, W, C) to (N, C, H, W), divided by 255.
                Assumes every image decodes to an (H, W, C) array of identical size.
            labels: one-hot (N, 2) labels; class 1 when the filename contains 'sunny', else 0.
            targets: one-hot (N, 2) of the opposite class (presumably the weather to translate to).
            noises: (N, noise_size) standard-normal noise.

        Raises:
            FileNotFoundError: if the test image directory is empty.
        """
        path_img = os.path.join(self.path, 'test_images')
        # Sorted so the same random seed picks the same files regardless of listdir order.
        images = sorted(os.listdir(path_img))
        if not images:
            raise FileNotFoundError(f'no images found in {path_img}')
        imgs, labels, targets = [], [], []
        for name in random.choices(images, k=self.num_gen):
            # Context manager closes the underlying file handle once the pixels are read.
            with Image.open(os.path.join(path_img, name)) as pil_img:
                img_array = np.array(pil_img)
            imgs.append(torch.tensor(img_array, dtype=dtype, device=device))
            label = torch.tensor(int('sunny' in name), device=device)
            target = 1 - label
            labels.append(torch.nn.functional.one_hot(label, 2))
            targets.append(torch.nn.functional.one_hot(target, 2))
        labels = torch.stack(labels, dim=0)
        targets = torch.stack(targets, dim=0)
        imgs = torch.stack(imgs, dim=0).permute(0, 3, 1, 2) / 255
        # Same dtype as the images/model (the original produced float32 noise for a float64 model).
        noises = torch.randn(size=(self.num_gen, self.noise_size), device=device, dtype=dtype)
        return imgs, labels, targets, noises

In [8]:
# A standard PyTorch Lightning Trainer. `device` is passed as `gpus`, i.e. the list of GPU
# indices to train on. Without a GPU, set `device = None` (or 0) to train on the CPU —
# the string 'cpu' is not a valid value for `gpus`.
device = [1]
image_callback = ImageGenCallback(num_gen=4, path=PATH, save_path='image_callbacks')
trainer = ptl.Trainer(
    gpus=device,
    max_epochs=20,
    callbacks=[checkpoint_callback, image_callback],
    logger=wandbLogger,
    log_every_n_steps=5,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [9]:
def print_image(tensor):
    """Draw a channels-first image tensor with matplotlib's ``imshow``.

    The tensor is detached from autograd, reordered from (C, H, W) to (H, W, C),
    moved to the CPU and converted to a NumPy array before plotting. Returns None.
    """
    channels_last = tensor.detach().permute(1, 2, 0)
    plt.imshow(channels_last.cpu().numpy())
    

In [10]:
# Freeze the expert's feature extractor; other expert parameters are left untouched by this cell.
for weight in model.expert.featureExtractor.parameters():
    weight.requires_grad_(False)

In [11]:
#Just making sure that the only trainable parameters are those of the expert and the generator
for name, param in model.named_parameters():
    if param.requires_grad == True:
        print(name)

generator.Ginit.downSampleBlock.convBlock1.conv.weight
generator.Ginit.downSampleBlock.convBlock1.conv.bias
generator.Ginit.downSampleBlock.convBlock1.layernorm.weight
generator.Ginit.downSampleBlock.convBlock1.layernorm.bias
generator.Ginit.downSampleBlock.convBlock2.conv.weight
generator.Ginit.downSampleBlock.convBlock2.conv.bias
generator.Ginit.downSampleBlock.convBlock2.layernorm.weight
generator.Ginit.downSampleBlock.convBlock2.layernorm.bias
generator.Ginit.downSampleBlock.convBlock3.conv.weight
generator.Ginit.downSampleBlock.convBlock3.conv.bias
generator.Ginit.downSampleBlock.convBlock3.layernorm.weight
generator.Ginit.downSampleBlock.convBlock3.layernorm.bias
generator.Ginit.upSampleBlock.convT1.tranpose_conv.weight
generator.Ginit.upSampleBlock.convT1.tranpose_conv.bias
generator.Ginit.upSampleBlock.convT2.tranpose_conv.weight
generator.Ginit.upSampleBlock.convT2.tranpose_conv.bias
generator.Ginit.upSampleBlock.convT3.tranpose_conv.weight
generator.Ginit.upSampleBlock.convT3

In [12]:
#fitting the model
# Fit the GAN: train on train_loader and validate on val_loader (third positional argument).
# Checkpointing, image snapshots and W&B logging come from the callbacks/logger attached to `trainer`.
# NOTE(review): the recorded run below was stopped by a KeyboardInterrupt, so it is not a completed training run.
trainer.fit(model, train_loader,val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
[34m[1mwandb[0m: Currently logged in as: [33madaminho[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name                 | Type                 | Params
--------------------------------------------------------------
0 | generator            | Gwithoutatt          | 4.2 M 
1 | expert               | Expert               | 190 K 
2 | weatherDiscriminator | WeatherDiscriminator | 190 K 
3 | ce                   | CrossEntropyLoss     | 0     
4 | l1                   | L1Loss               | 0     
--------------------------------------------------------------
674 K     Trainable params
4.0 M     Non-trainable params
4.6 M     Total params
18.500    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Training: -1it [00:00, ?it/s]

Process wandb_internal:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/lib/python3.7/site-packages/wandb/sdk/internal/internal.py", line 153, in wandb_internal
    thread.join()
  File "/opt/conda/lib/python3.7/threading.py", line 1044, in join
    self._wait_for_tstate_lock()
  File "/opt/conda/lib/python3.7/threading.py", line 1060, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterrupt
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/conda/lib/python3.7/multiprocessing/spawn.py", line 105, in spawn_main
    exitcode = _main(fd)
  File "/opt/conda/lib/python3.7/multiprocessing/spawn.py", line 118, in _main
    return self._bootstrap()
  File "/opt/conda/lib/python3.7/multiprocessing/process.p

Error in callback <function _WandbInit._pause_backend at 0x7f7f615a90e0> (for post_run_cell):


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


Exception: The wandb backend process has shutdown