### Google Drive requirements:
1. `CS-COCO.zip` dataset file
2. `adain` package and `train.py` module
3. `content.jpg` and `style.jpg` images (must be the same size)
4. `vgg19-norm.pth` file with VGG19 weights
6. `checkpoint-<iter_num>.pth` (optional — only needed if you want to resume training from a checkpoint)

In [None]:
from google.colab import drive

# Mount Google Drive into the Colab filesystem; the later cells read the
# dataset archive, images and weights from beneath this mount point.
mount_point = '/content/gdrive/'
drive.mount(mount_point)

Mounted at /content/gdrive/


In [None]:
import os
import sys

# Root folder of the mounted Google Drive.
path2drive = '/content/gdrive/My Drive'

# Put the Drive root on the import path so that modules stored there
# can be imported by the cells below.
sys.path.append(path2drive)

# Folder on Drive that receives the files copied out during training.
output = os.path.join(path2drive, 'output')

In [None]:
import zipfile

# Local (non-Drive) directory the dataset is unpacked into.
path2ds = '/content/CS-COCO'
# Dataset archive stored in the Drive root.
zip_file = os.path.join(path2drive, 'CS-COCO.zip')

# Use a context manager so the archive's file handle is closed as soon as
# extraction finishes (or fails) instead of staying open for the session.
with zipfile.ZipFile(zip_file, 'r') as archive:
    archive.extractall(path2ds)

In [None]:
import shutil
from utils.constants import CHECKPOINTS_DIR, PATH2VGG, LOGS_DIR

# Create every working directory idempotently: `exist_ok=True` makes the cell
# safe to re-run and also creates any missing parent directories.
os.makedirs(output, exist_ok=True)

# NOTE(review): 'extra', CHECKPOINTS_DIR and LOGS_DIR are created relative to
# the current working directory, while the copy destinations below are
# absolute paths under /content — this assumes the notebook runs with
# /content as its working directory; confirm.
os.makedirs('extra', exist_ok=True)

# Copy each input image whenever its destination is missing, rather than only
# when the directory was just created. Otherwise a first run that failed
# half-way (e.g. an image was absent from Drive) would leave an empty
# directory and every later run would silently skip the copy.
for image_name in ('content.jpg', 'style.jpg'):
    image_dst = f'/content/extra/{image_name}'
    if not os.path.exists(image_dst):
        shutil.copy(os.path.join(path2drive, image_name), image_dst)

os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

# Same reasoning for the VGG19 weights: copy them if they are not in place yet.
vgg_dst = f'/content/{PATH2VGG}'
if not os.path.exists(vgg_dst):
    shutil.copy(os.path.join(path2drive, 'vgg19-norm.pth'), vgg_dst)

os.makedirs(LOGS_DIR, exist_ok=True)

In [None]:
import re

# NOTE(review): the star import presumably supplies StyleTransferNetwork and
# DEVICE used below — confirm, and prefer importing them explicitly.
from utils import *
from utils.utils import save_model, save_log
from train import train

import torch

model = StyleTransferNetwork().to(DEVICE)

# Collect every file in the Drive root named exactly `checkpoint-<number>.pth`.
# A strict pattern is used because a plain `startswith('checkpoint')` test also
# accepts unrelated entries (e.g. a `checkpoints` folder) whose name cannot be
# parsed into a number.
checkpoint_pattern = re.compile(r'checkpoint-(\d+)\.pth')
checkpoints = {}
for filename in os.listdir(path2drive):
    match = checkpoint_pattern.fullmatch(filename)
    if match:
        checkpoints[int(match.group(1))] = filename

# `os.listdir` returns entries in arbitrary order, so when several checkpoints
# are present pick the one with the highest number explicitly instead of
# whichever happens to be listed first.
start = 0
if checkpoints:
    start = max(checkpoints)

    path = os.path.join(path2drive, checkpoints[start])
    state_dict = torch.load(path, map_location=DEVICE)
    model.load_state_dict(state_dict)

    print('Successful uploading')

# `train` yields a number each time it reaches a save point; at each one the
# model and the log are written locally and then mirrored to the Drive output
# folder so they are still available after the Colab session ends.
for iter_num in train(model, path2ds, start):
    path2model = save_model(model, iter_num)
    shutil.copy(path2model, os.path.join(output, os.path.basename(path2model)))

    path2log = save_log(model, iter_num)
    shutil.copy(path2log, os.path.join(output, os.path.basename(path2log)))

Successful uploading


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Epoch: 15; samples: 100%|██████████| 28000/28000 [1:16:00<00:00,  6.14 sample/s]


Mean content loss: 1.3570023775100708
Mean style loss: 0.37971818447113037
Mean common loss: 5.154184341430664


Epoch: 16; samples:  65%|██████▍   | 51760/80000 [2:20:56<1:20:09,  5.87 sample/s]