In [1]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
from models.WGAN_Model import WGAN

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import time
from datamodules.ganDataset import ganDataset
from torch.utils.data import DataLoader

In [3]:
# Sequence length handed to the model constructor as `seq_len`.
SEQ_LEN: int = 164
# Location passed to ganDataset. Despite the name, this points at a single
# text file rather than a directory.
data_dir: str = "../Data/activity_summary_stats_and_metadata.txt"

In [4]:
# Seed Python, NumPy and torch RNGs in one call. `workers=True` additionally
# asks Lightning to derive a seed for every DataLoader worker process, which
# matters here because the training DataLoader runs with num_workers > 0 and
# the Trainer is configured with deterministic=True.
pl.seed_everything(7, workers=True)

[rank: 0] Global seed set to 7


7

In [5]:
# Weights & Biases logger for this run. The run name is the launch time
# formatted as year-month-day-hour-minute.
wandb_logger = WandbLogger(project='BCLab-WGAN', name=time.strftime('%Y-%m-%d-%H-%M'))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpeter6866[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# Dataset constructed from `data_dir`, served in shuffled mini-batches of 64
# by four loader worker processes.
dataset = ganDataset(data_dir)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)

In [7]:
# Instantiate the WGAN with the configured sequence length. The model summary
# printed during fitting lists a `generator` and a `critic` submodule.
model = WGAN(seq_len=SEQ_LEN)

In [8]:
# Lightning trainer: GPU accelerator with devices=-1 (use every visible
# device), capped at 100 epochs, deterministic algorithms requested, and
# metrics routed to the W&B logger.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=-1,
    max_epochs=100,
    deterministic=True,
    logger=wandb_logger,
)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
# Allow faster, reduced-precision float32 matmul kernels where the hardware
# offers them, then train the model on the single training DataLoader.
torch.set_float32_matmul_precision('high')
trainer.fit(model=model, train_dataloaders=dataloader)

  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type      | Params
----------------------------------------
0 | generator | Generator | 2.2 M 
1 | critic    | Critic    | 517 K 
----------------------------------------
2.7 M     Trainable params
0         Non-trainable params
2.7 M     Total params
10.703    Total estimated model params size (MB)


Epoch 99: 100%|██████████| 191/191 [00:12<00:00, 15.16it/s, v_num=sbyu, d_loss=2.050, g_loss=-.165]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 191/191 [00:12<00:00, 15.11it/s, v_num=sbyu, d_loss=2.050, g_loss=-.165]


In [10]:
# One standard-normal noise vector with 100 entries, used as generator input.
noise = torch.randn((100,))

In [11]:
# Close the active W&B run; the run history and summary are printed below.
wandb.finish()



0,1
d_loss,▇█▆█▅▅▄▄▃▂▁▅▇▃▂▅▅▁▃▃▄▇▇▄▁▃▇▇▁▃▆▅▆▆▆▃█▃▂▄
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
g_loss,▂▅▆▂▂▄▃▄█▆▄▇▃▆▇▃▂▅▃▄▄▃▁▄▃▂▅▄▄▂▃▃▄▃▇█▅▂▇█
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
d_loss,2.05162
epoch,99.0
g_loss,-0.16546
trainer/global_step,19099.0


In [12]:
import numpy as np
import selene_sdk

In [13]:
# Sampling is inference only: switch the model to evaluation mode so any
# train-time layer behaviour is disabled, and skip autograd bookkeeping so no
# graph is built or kept alive for the generated tensor.
model.eval()
with torch.no_grad():
    sample = model(noise)
sample.shape

torch.Size([1, 4, 164])

In [14]:
# Decoding to DNA sequence
# `sample` has shape (1, 4, 164), so dim=1 is the 4-way axis. Take the
# highest-scoring entry at every position and map it to a base letter.
# argmax is used instead of `_, indices = torch.max(...)`: rebinding `_` in a
# notebook stops IPython from updating its last-output variables, and the max
# values themselves were never used.
indices = sample.argmax(dim=1)
# NOTE(review): assumes channel order A, C, G, T matches the encoding used by
# the training data — confirm against ganDataset.
bases = {0: 'A', 1: 'C', 2: 'G', 3: 'T'}
# Only the first (and here only) item of the batch is decoded.
dna_sequence = ''.join(bases[i] for i in indices[0].tolist())

print(dna_sequence)

TCAGCTGCCCGTTCTACTGACTAAAGCAGACATACGGCCGAGTCCGACCGGCGGTGGGTGGGTGCACGTTGACTAGCGCGGTGGAACAGTGACCGGGCTCCGTTTTAGAGAAATTCGAGGGGTCCCTGATGTGACCCGGTCGCAGGCGCGAAGTTGAAGTGGGG


In [15]:
# Draw a fresh noise vector, generate one sample, and decode it to a DNA
# string. Generation runs in evaluation mode without autograd tracking since
# no gradients are needed here.
model.eval()
with torch.no_grad():
    noise = torch.randn(100)
    sample = model(noise)
# argmax over dim=1 picks the highest-scoring of the 4 entries per position;
# it avoids rebinding `_`, which would disable IPython's last-output variable.
indices = sample.argmax(dim=1)
# NOTE(review): assumes channel order A, C, G, T matches the training
# encoding — confirm against ganDataset.
bases = {0: 'A', 1: 'C', 2: 'G', 3: 'T'}
dna_sequence = ''.join(bases[i] for i in indices[0].tolist())

print(dna_sequence)

GAATGTGACAAGCGGCTGACGAGCCTCAAAAGCGTCACGAAGACACGTATTACGGGGACAAGTTTGGGCCTGCGGAGGCAGACACGTCAGCGCGGGGAGCGCTGCCGTTACATGCTGACTAGGGCTCATTTGGCGGGAGCGAACGCTCGACAGACAGGTGAGCG


In [19]:
# Generate and decode one more sequence from new noise. The model is put in
# evaluation mode and autograd is disabled because this is pure inference.
model.eval()
with torch.no_grad():
    noise = torch.randn(100)
    sample = model(noise)
# Index of the largest of the 4 entries at each position. Using argmax keeps
# `_` untouched so IPython's last-output variable continues to work.
indices = sample.argmax(dim=1)
# NOTE(review): assumes channel order A, C, G, T matches the training
# encoding — confirm against ganDataset.
bases = {0: 'A', 1: 'C', 2: 'G', 3: 'T'}
dna_sequence = ''.join(bases[i] for i in indices[0].tolist())

print(dna_sequence)

GCGGACCAAGGCAAGGCGGCTTTGTAGTAGAAGGAGAGTCGTGCGCGGCTGGCTAGAGTTTTCCGTATGGCAGTGCGGACCATGGTGGAATACAATTTGGGGGTCACGGGGGCAGTGCCGGAAGGGCGGAACAAGGCTTACGGTGCCTGCGAACATGGCGCTCA
