# Training script



In [2]:
!pip install albumentations
!pip install librosa
!pip install torch
!pip install numpy

Collecting albumentations
  Downloading albumentations-1.3.0-py3-none-any.whl (123 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m123.5/123.5 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Collecting qudida>=0.0.4
  Downloading qudida-0.0.4-py3-none-any.whl (3.5 kB)
Collecting scikit-image>=0.16.1
  Downloading scikit_image-0.20.0-cp39-cp39-macosx_12_0_arm64.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m50.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting opencv-python-headless>=4.1.1
  Downloading opencv_python_headless-4.7.0.72-cp37-abi3-macosx_11_0_arm64.whl (32.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m32.6/32.6 MB[0m [31m48.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting imageio>=2.4.1
  Downloading imageio-2.27.0-py3-none-any.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m57.8 MB/s[0m eta [36m0:00:

In [3]:
import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import cv2
from pdb import set_trace
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
from PIL import Image

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision.utils import save_image

import soundfile
import sys
from tqdm import tqdm
import random

In [4]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
max_size_x = 1025
max_size_y = 650
DEVICE = 'cuda'
TRAIN_DIR = 'Images/train'
VAL_DIR = 'Images/val'
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 0
NUM_EPOCHS = 1
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_Trap = 'checkpoints/genh.pth.tar'
CHECKPOINT_GEN_Rock = 'checkpoints/genz.pth.tar'
CHECKPOINT_CRITIC_Trap = 'checkpoints/critich.pth.tar'
CHECKPOINT_CRITIC_Rock = 'checkpoints/criticz.pth.tar'

## LOSS ARRAYS 

D_Trap_real_loss_array = []
D_Trap_fake_loss_array = []
D_Trap_loss_array =  []

D_Rock_real_loss_array = []
D_Rock_fake_loss_array = []
D_Rock_loss_array = []

D_loss_array = []

loss_G_Trap_array = []
loss_G_Rock_array = []

cycle_rock_loss_array = []
cycle_trap_loss_array = []

G_loss_array = []


In [5]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

def seed_everything(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [6]:
class TrapRockDataset(Dataset):
    def __init__(self, root_rock, root_trap, transform=None):
        self.root_rock = root_rock
        self.root_trap = root_trap
        self.transform = transform

        self.rock_images = os.listdir(root_rock)
        self.trap_images = os.listdir(root_trap)
        self.length_dataset = max(len(self.rock_images), len(self.trap_images)) # 1000, 1500
        self.rock_len = len(self.rock_images)
        self.trap_len = len(self.trap_images)

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        rock_img = self.rock_images[index % self.rock_len]
        trap_img = self.trap_images[index % self.trap_len]

        rock_path = os.path.join(self.root_rock, rock_img)
        trap_path = os.path.join(self.root_trap, trap_img)

        rock_img = np.array(Image.open(rock_path))
        trap_img = np.array(Image.open(trap_path))

        if self.transform:
            #set_trace()
            augmentations = self.transform(image=rock_img, image0=trap_img)
            rock_img = augmentations["image"]
            trap_img = augmentations["image0"]
        # set_trace()
        return rock_img, trap_img

In [7]:
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                4,
                stride,
                1,
                bias=True,
                padding_mode="reflect",
            ),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


In [8]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=1, features=[64, 128, 256, 512]):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                Block(in_channels, feature, stride=1 if feature == features[-1] else 2)
            )
            in_channels = feature
        layers.append(
            nn.Conv2d(
                in_channels,
                1,
                kernel_size=4,
                stride=1,
                padding=1,
                padding_mode="reflect",
            )
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        #set_trace()
        #x = x.unsqueeze(1)
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

In [9]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity(),
        )

    def forward(self, x):
        return self.conv(x)


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            ConvBlock(channels, channels, kernel_size=3, padding=1),
            ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.block(x)



In [12]:
class Generator(nn.Module):
    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                img_channels,
                num_features,
                kernel_size=7,
                stride=1,
                padding=3,
                padding_mode="reflect",
            ),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        )
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(
                    num_features, num_features * 2, kernel_size=3, stride=2, padding=1
                ),
                ConvBlock(
                    num_features * 2,
                    num_features * 4,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
            ]
        )
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
        )
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(
                    num_features * 4,
                    num_features * 2,
                    down=False,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                ConvBlock(
                    num_features * 2,
                    num_features * 1,
                    down=False,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
            ]
        )

        self.last = nn.Conv2d(
            num_features * 1,
            img_channels,
            kernel_size=7,
            stride=1,
            padding=3,
            padding_mode="reflect",
        )

    def forward(self, x):
        x = self.initial(x)
        for layer in self.down_blocks:
            x = layer(x)
        x = checkpoint.checkpoint_sequential(self.res_blocks, segments=len(self.res_blocks), input=x) # Use checkpoints for residual blocks
        for layer in self.up_blocks:
            x = layer(x)
        #set_trace()
        return torch.tanh(self.last(x))[: , : ,:max_size_x, :(max_size_y)]


def train_fn(
    disc_Trap, disc_Rock, gen_Rock, gen_Trap, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler, epoch, 
    D_Trap_real_loss_sum, D_Trap_fake_loss_sum, D_Trap_loss_sum,D_Rock_real_loss_sum, D_Rock_fake_loss_sum, D_Rock_loss_sum, D_loss_sum, 
    loss_G_Trap_sum, loss_G_Rock_sum,cycle_rock_loss_sum,  cycle_trap_loss_sum, G_loss_sum
):
    Trap_reals = 0
    Trap_fakes = 0
    #set_trace()
    loop = tqdm(loader, leave=True)
    #set_trace()
    for idx, (rock, trap) in enumerate(loop):
        rock = rock.to(DEVICE, dtype=torch.float32)
        trap = trap.to(DEVICE, dtype=torch.float32)
        rock = rock.unsqueeze(1)
        trap = trap.unsqueeze(1)
        # Train Discriminators H and Z
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            #set_trace()
            fake_trap = gen_Trap(rock)
            D_Trap_real = disc_Trap(trap)
            D_Trap_fake = disc_Trap(fake_trap.detach())
            Trap_reals += D_Trap_real.mean().item()
            Trap_fakes += D_Trap_fake.mean().item()

            D_Trap_real_loss = mse(D_Trap_real, torch.ones_like(D_Trap_real))
            D_Trap_real_loss_sum += D_Trap_real_loss.detach().cpu().numpy().item()

            D_Trap_fake_loss = mse(D_Trap_fake, torch.zeros_like(D_Trap_fake))
            D_Trap_fake_loss_sum += D_Trap_fake_loss.detach().cpu().numpy().item()

            D_Trap_loss = D_Trap_real_loss + D_Trap_fake_loss
            D_Trap_loss_sum += D_Trap_loss.detach().cpu().numpy().item()


            fake_rock = gen_Rock(trap)
            D_Rock_real = disc_Rock(rock)
            D_Rock_fake = disc_Rock(fake_rock.detach())
            D_Rock_real_loss = mse(D_Rock_real, torch.ones_like(D_Rock_real))
            D_Rock_real_loss_sum += D_Rock_real_loss.detach().cpu().numpy().item()

            D_Rock_fake_loss = mse(D_Rock_fake, torch.zeros_like(D_Rock_fake))
            D_Rock_fake_loss_sum += D_Rock_fake_loss.detach().cpu().numpy().item()

            D_Rock_loss = D_Rock_real_loss + D_Rock_fake_loss
            D_Rock_loss_sum += D_Rock_loss.detach().cpu().numpy().item()



            # put it togethor
            D_loss = (D_Trap_loss + D_Rock_loss) / 2
            D_loss_sum += D_loss.detach().cpu().numpy().item()


        opt_disc.zero_grad(set_to_none=True)
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generators H and Z
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            # adversarial loss for both generators
            D_Trap_fake = disc_Trap(fake_trap)
            D_Rock_fake = disc_Rock(fake_rock)
            loss_G_Trap = mse(D_Trap_fake, torch.ones_like(D_Trap_fake))
            loss_G_Trap_sum += loss_G_Trap.detach().cpu().numpy().item()


            loss_G_Rock = mse(D_Rock_fake, torch.ones_like(D_Rock_fake))
            loss_G_Rock_sum += loss_G_Rock.detach().cpu().numpy().item()


            # set_trace()
            # cycle loss
            cycle_rock = gen_Rock(fake_trap)
            cycle_trap = gen_Trap(fake_rock)
            cycle_rock_loss = l1(rock, cycle_rock)
            cycle_rock_loss_sum += cycle_rock_loss.detach().cpu().numpy().item()

            cycle_trap_loss = l1(trap, cycle_trap)
            cycle_trap_loss_sum += cycle_trap_loss.detach().cpu().numpy().item()



            # add all togethor
            G_loss = (
                loss_G_Rock
                + loss_G_Trap
                + cycle_rock_loss * LAMBDA_CYCLE
                + cycle_trap_loss * LAMBDA_CYCLE
            )
            G_loss_sum += G_loss.detach().cpu().numpy().item()


        opt_gen.zero_grad(set_to_none=True)
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 10 == 0:
            save_image(fake_trap * 0.5 + 0.5, f"Fake images/trap_{idx}.png")
            save_image(fake_rock * 0.5 + 0.5, f"Fake images/rock_{idx}.png")
            #set_trace()
            hop_length = 512
            n_fft = 2048
            sr = 22500

            #trap audio
            spectrogram = cv2.imread(f"Fake images/trap_{idx}.png", cv2.IMREAD_GRAYSCALE).astype('float32')
            # Normalize the spectrogram to the range [0, 1]
            spectrogram /= 255.0
            # Reconstruct the audio waveform from the spectrogram
            audio = librosa.griffinlim(spectrogram, hop_length=hop_length, win_length=n_fft)
            file_name_full = f"Fake Audio/trap_{idx}_EPOCH{epoch}.wav"
            soundfile.write(file_name_full, audio, samplerate=sr)

            #rock audio
            spectrogram = cv2.imread(f"Fake images/rock_{idx}.png", cv2.IMREAD_GRAYSCALE).astype('float32')
            # Normalize the spectrogram to the range [0, 1]
            spectrogram /= 255.0
            # Reconstruct the audio waveform from the spectrogram
            audio = librosa.griffinlim(spectrogram, hop_length=hop_length, win_length=n_fft)
            file_name_full = f"Fake Audio/rock_{idx}_EPOCH{epoch}.wav"
            soundfile.write(file_name_full, audio, samplerate=sr)


        loop.set_postfix(Trap_real=Trap_reals / (idx + 1), Trap_fake=Trap_fakes / (idx + 1))
    return D_Trap_real_loss_sum,D_Trap_fake_loss_sum,D_Trap_loss_sum,D_Rock_real_loss_sum, D_Rock_fake_loss_sum, D_Rock_loss_sum,D_loss_sum, loss_G_Trap_sum, loss_G_Rock_sum,cycle_rock_loss_sum, cycle_trap_loss_sum, G_loss_sum


In [13]:
def main():
  
    disc_Trap = Discriminator(in_channels=1).to(DEVICE)
    disc_Rock = Discriminator(in_channels=1).to(DEVICE)

    gen_Rock = Generator(img_channels=1, num_residuals=9).to(DEVICE)
    gen_Trap = Generator(img_channels=1, num_residuals=9).to(DEVICE)
    opt_disc = optim.Adam(
        list(disc_Trap.parameters()) + list(disc_Rock.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    opt_gen = optim.Adam(
        list(gen_Rock.parameters()) + list(gen_Trap.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN_Trap,
            gen_Trap,
            opt_gen,
            LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_GEN_Rock,
            gen_Rock,
            opt_gen,
            LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_Trap,
            disc_Trap,
            opt_disc,
            LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_CRITIC_Rock,
            disc_Rock,
            opt_disc,
            LEARNING_RATE,
        )

    dataset = TrapRockDataset(
        root_trap=TRAIN_DIR + "/trap",
        root_rock=TRAIN_DIR + "/rock",
        # transform=transforms,
    )
    val_dataset = TrapRockDataset(
        root_trap=VAL_DIR + "/trap",
        root_rock=VAL_DIR + "/rock",
        # transform=transforms,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):

        D_Trap_real_loss_sum = 0
        D_Trap_fake_loss_sum = 0
        D_Trap_loss_sum = 0
        D_Rock_real_loss_sum = 0
        D_Rock_fake_loss_sum = 0
        D_Rock_loss_sum = 0
        D_loss_sum = 0
        loss_G_Trap_sum = 0
        loss_G_Rock_sum = 0
        cycle_rock_loss_sum = 0
        cycle_trap_loss_sum = 0
        G_loss_sum = 0

        D_Trap_real_loss_sum,D_Trap_fake_loss_sum,D_Trap_loss_sum,D_Rock_real_loss_sum, D_Rock_fake_loss_sum, D_Rock_loss_sum,D_loss_sum, loss_G_Trap_sum, loss_G_Rock_sum,cycle_rock_loss_sum, cycle_trap_loss_sum, G_loss_sum=  train_fn(
            disc_Trap,
            disc_Rock,
            gen_Rock,
            gen_Trap,
            loader,
            opt_disc,
            opt_gen,
            L1,
            mse,
            d_scaler,
            g_scaler,
            epoch,
            D_Trap_real_loss_sum,
            D_Trap_fake_loss_sum,
            D_Trap_loss_sum,
            D_Rock_real_loss_sum, 
            D_Rock_fake_loss_sum, 
            D_Rock_loss_sum,
            D_loss_sum, 
            loss_G_Trap_sum, 
            loss_G_Rock_sum,
            cycle_rock_loss_sum, 
            cycle_trap_loss_sum, 
            G_loss_sum
        )

        if SAVE_MODEL:
            save_checkpoint(gen_Trap, opt_gen, filename=CHECKPOINT_GEN_Trap)
            save_checkpoint(gen_Rock, opt_gen, filename=CHECKPOINT_GEN_Rock)
            save_checkpoint(disc_Trap, opt_disc, filename=CHECKPOINT_CRITIC_Trap)
            save_checkpoint(disc_Rock, opt_disc, filename=CHECKPOINT_CRITIC_Rock)

        D_Trap_real_loss_array.append(D_Trap_real_loss_sum)
        D_Trap_fake_loss_array.append(D_Trap_fake_loss_sum)
        D_Trap_loss_array.append(D_Trap_loss_sum)
        D_Rock_real_loss_array.append(D_Rock_real_loss_sum)
        D_Rock_fake_loss_array.append(D_Rock_fake_loss_sum)
        D_Rock_loss_array.append(D_Rock_loss_sum)
        D_loss_array.append(D_loss_sum)
        loss_G_Trap_array.append(loss_G_Trap_sum)
        loss_G_Rock_array.append(loss_G_Rock_sum)
        cycle_rock_loss_array.append(cycle_rock_loss_sum)
        cycle_trap_loss_array.append(cycle_trap_loss_sum)
        G_loss_array.append(G_loss_sum)

        np.savetxt('D_Trap_real_loss.csv', D_Trap_real_loss_array, delimiter=',')
        np.savetxt('D_Trap_fake_loss.csv', D_Trap_fake_loss_array, delimiter=',')
        np.savetxt('D_Trap_loss.csv', D_Trap_loss_array, delimiter=',')
        np.savetxt('D_Rock_real_loss.csv', D_Rock_real_loss_array, delimiter=',')
        np.savetxt('D_Rock_fake_loss.csv', D_Rock_fake_loss_array, delimiter=',')
        np.savetxt('D_Rock_loss.csv', D_Rock_loss_array, delimiter=',')
        np.savetxt('D_loss.csv', D_loss_array, delimiter=',')
        np.savetxt('loss_G_Trap.csv', loss_G_Trap_array, delimiter=',')
        np.savetxt('loss_G_Rock.csv', loss_G_Rock_array, delimiter=',')
        np.savetxt('cycle_rock_loss.csv', cycle_rock_loss_array, delimiter=',')
        np.savetxt('cycle_trap_loss.csv', cycle_trap_loss_array, delimiter=',')
        np.savetxt('G_loss.csv', G_loss_array, delimiter=',')


if __name__ == "__main__":
    main()

AssertionError: Torch not compiled with CUDA enabled