In [1]:
# Copyright (c) Xi Chen
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Borrowed from https://github.com/neocxi/pixelsnail-public and ported it to PyTorch

In [1]:
import gc
import importlib 
from math import sqrt
from functools import partial, lru_cache

import pandas as pd
import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

from vq_vae_2 import Model
from pixel_snail import PixelSNAIL
from scheduler import CycleScheduler

import visdom
import matplotlib.pyplot as plt

In [2]:
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class MelData(Dataset):
    def __init__(self, file_path):
        self.data = []
        genre = ['classical', 'rock', 'electronic', 'pop']
        
        for g in genre:
            for i in range(1, 101):
                for j in range(5):
                    tmp_path = f'{file_path}/{g}/{i}-{j}.csv'
                    try:
                        self.data.append((pd.read_csv(tmp_path), g, i, j))
                    except FileNotFoundError:
                        print(f"{g}-{i}-{j} file is deleted")
                        continue
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        mel, g, i, j = self.data[idx]
        mel = torch.from_numpy(pd.get_dummies(mel).values)
        mel = mel.type(torch.cuda.FloatTensor)
        
        return (g, i, j), mel

class EmotionalData(Dataset):
    def __init__(self, file_path):
        self.data = pd.read_csv(file_path)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        tmp_data = self.data.iloc[idx]
        genre, idx = tmp_data[0].split('_')
        emo = tmp_data[1:]
        return idx, genre, torch.FloatTensor(emo)

In [4]:
batch_size = 32

In [5]:
EMO_PATH = "./mean_data.csv"
MEL_ARR_PATH = "./split_mel_array"
SAVE_PATH = "./save_models"
    
mel_arr_data = MelData(MEL_ARR_PATH)
#emo_data = EmotionalData(EMO_PATH)


classical-16-0 file is deleted
classical-16-1 file is deleted
classical-16-2 file is deleted
classical-16-3 file is deleted
classical-16-4 file is deleted
classical-40-0 file is deleted
classical-40-1 file is deleted
classical-40-2 file is deleted
classical-40-3 file is deleted
classical-40-4 file is deleted
classical-57-0 file is deleted
classical-57-1 file is deleted
classical-57-2 file is deleted
classical-57-3 file is deleted
classical-57-4 file is deleted
classical-66-0 file is deleted
classical-66-1 file is deleted
classical-66-2 file is deleted
classical-66-3 file is deleted
classical-66-4 file is deleted
classical-73-0 file is deleted
classical-73-1 file is deleted
classical-73-2 file is deleted
classical-73-3 file is deleted
classical-73-4 file is deleted


In [None]:
mel_arr_data_loader = DataLoader(
        dataset=mel_arr_data, batch_size=batch_size)

#emo_data_loader = DataLoader(
#        dataset=emo_data, batch_size=batch_size, shuffle=True)

In [None]:
def scaled(x):
    return x + 80.0
def unscaled(x):
    return x - 80.0

In [None]:
# extract idices
# torch.Size([32, 10, 128]) torch.Size([32, 20, 256])
def extract_indice(mel_data, model):
    with torch.no_grad():
        for _, mel in mel_data:
            x = scaled(mel)
            x = x[:, :, :-4].unsqueeze(1).to(device)
            _, _, _, _, id_t, id_b = model.encode(x)
            try:
                ids_t = torch.cat([ids_t, id_t], dim=0)
                ids_b = torch.cat([ids_b, id_b], dim=0)
            except Exception as e:
                print(e)
                ids_t, ids_b = id_t.clone().detach(), id_b.clone().detach()
    return ids_t, ids_b
            

In [None]:
num_hiddens = 128 #128
num_residual_hiddens = 32
num_residual_layers = 4
embedding_dim = 64 #64
num_embeddings = 512 #512
commitment_cost = 0.25

In [None]:
import vq_vae_2
importlib.reload(vq_vae_2)
from vq_vae_2 import Model

In [None]:
model = Model(num_hiddens=num_hiddens, 
                  num_residual_layers=num_residual_layers,
                  num_residual_hiddens=num_residual_hiddens,
                  num_embeddings=num_embeddings,
                  embedding_dim=embedding_dim, 
                  commitment_cost=commitment_cost).to(device)

score = 119.51515197753906
MODEL_PATH = f'{SAVE_PATH}/vqvae2_light-{score:.5f}_dict.pt'
model.load_state_dict(torch.load(MODEL_PATH))

<All keys matched successfully>

In [None]:
ids_t, ids_b = extract_indice(mel_arr_data_loader, model)
del mel_arr_data_loader, mel_arr_data

local variable 'ids_t' referenced before assignment


In [None]:
print(ids_t.size(), ids_b.size())

torch.Size([1975, 5, 64]) torch.Size([1975, 10, 128])


In [None]:
gc.collect()

63

In [None]:
def train(hier, epoch, loader, model, optimizer, scheduler, device, batch_size):
    
    loader = tqdm(loader)

    criterion = nn.CrossEntropyLoss()
    
    for i, (top, bottom) in enumerate(loader):
        model.zero_grad()

        top = top.to(device)
        # print(top.shape, bottom.shape)

        if hier == 'top':
            target = top
            out, _ = model(top)

        elif hier == 'bottom':
            target = bottom.to(device) # bottom
            out, _ = model(target, condition=top)
            out = out

        loss = criterion(out, target)
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        _, pred = out.max(1)
        correct = (pred == target).float()
        accuracy = correct.sum() / target.numel()

        lr = optimizer.param_groups[0]['lr']

        loader.set_description(
            (
                f'epoch: {epoch + 1}\t loss: {loss.item():.5f}\t '
                f'acc: {accuracy:.5f}\t lr: {lr:.5f}'
            )
        )

        # visualization
        vis.line(X=torch.Tensor([i + (1968 / batch_size) * epoch]),
        Y=torch.Tensor([loss.item()]),
        win = plt,
        update='append'
        )

class PixelTransform:
    def __init__(self):
        pass

    def __call__(self, input):
        ar = np.array(input)

        return torch.from_numpy(ar).long()

In [None]:
batch = 4
epoch = 150
lr = 1e-4 # 1e-4 bottom
channel = 512
n_res_block = 4
n_res_channel = 256
n_out_res_block = 0
n_cond_res_block = 3
dropout = 0.1

In [None]:
class IDsData(Dataset):
    def __init__(self, ids_t, ids_b):
        self.ids_t = ids_t
        self.ids_b = ids_b
        
    def __len__(self):
        return len(self.ids_t)
    
    def __getitem__(self, idx):
        return self.ids_t[idx], self.ids_b[idx]

In [None]:
ids_data = IDsData(ids_t, ids_b)

In [None]:
ids_loader = DataLoader(
        ids_data, batch_size=batch, shuffle=True, drop_last=True
    )

In [20]:
import importlib, pixel_snail
importlib.reload(pixel_snail)
from pixel_snail import PixelSNAIL

In [None]:
model_top = PixelSNAIL(
    [5, 64], # [32, 40],
    512,
    channel,
    5,
    4,
    n_res_block,
    n_res_channel,
    dropout=dropout,
    n_out_res_block=n_out_res_block,
)


model_bottom = PixelSNAIL(
    [10, 128], # [20, 256]
    512,
    channel,
    5,
    4,
    n_res_block,
    n_res_channel,
    attention=False,
    dropout=dropout,
    n_cond_res_block=n_cond_res_block,
    cond_res_channel=n_res_channel,
)

In [None]:
model_top = model_top.to(device)
optimizer_top = optim.Adam(model_top.parameters(), lr=lr)

model_bottom = model_bottom.to(device)
optimizer_bottom = optim.Adam(model_bottom.parameters(), lr=lr)

#scheduler_top = CycleScheduler(optimizer_top, lr, n_iter=len(ids_t_loader) * epoch, momentum=None)
#scheduler_bottom = CycleScheduler(optimizer_bottom, lr, n_iter=len(ids_b_loader) * epoch, momentum=None)

In [23]:
del model_top

In [24]:
gc.collect()
torch.cuda.empty_cache()

In [25]:
vis = visdom.Visdom()

Setting up a new session...


In [27]:
vis.close(env="main")
plt = vis.line(Y=torch.Tensor(1).zero_())
for i in range(epoch):
    loss = train('top', i, ids_loader, model_top, optimizer_top, None, device, batch)
    if i % 100 == 0:
        torch.save(model_top.state_dict(), f'{SAVE_PATH}/pixelsnail_ckp/pixelsnail_top_{str(i).zfill(3)}.pt')
        print(f'Save pixelsnail_top_{str(i).zfill(3)}.pt')

  0%|          | 0/493 [00:00<?, ?it/s]

Save pixelsnail_top_000.pt


  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

Save pixelsnail_top_100.pt


  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

In [26]:
torch.save(model_top.state_dict(), f'{SAVE_PATH}/pixelsnail_ckp/pixelsnail_top_{str(63).zfill(3)}.pt')

NameError: name 'model_top' is not defined

In [27]:
vis.close(env="main")
plt = vis.line(Y=torch.Tensor(1).zero_())
for i in tqdm(range(epoch)):
    train('bottom', i, ids_loader, model_bottom, optimizer_bottom, None, device, batch)
    if i % 50 == 0:
        torch.save(model_bottom.state_dict(), f'{SAVE_PATH}/pixelsnail_ckp/pixelsnail_bot_{str(i).zfill(3)}.pt')
        print(f'Save pixelsnail_bot_{str(i).zfill(3)}.pt')

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

Save pixelsnail_bot_000.pt


  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

Save pixelsnail_bot_050.pt


  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

Save pixelsnail_bot_100.pt


  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

  0%|          | 0/493 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [24]:
torch.save(model_bottom.state_dict(), f'{SAVE_PATH}/pixelsnail_ckp/pixelsnail_bot_130.pt')