In [1]:
import os
import dataclasses
import numpy as np
import math
import pandas as pd

import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR

from tqdm import tqdm
import pickle

from torchmetrics.image.inception import InceptionScore
import matplotlib.pyplot as plt
import py3Dmol

In [2]:
# Names of the 37 heavy-atom slots that index the per-residue atom axis.
# The order is significant: position i along that axis is atom_types[i].
atom_types = (
    'N CA C CB O CG CG1 CG2 OG OG1 SG CD '
    'CD1 CD2 ND1 ND2 OD1 OD2 SD CE CE1 CE2 CE3 '
    'NE NE1 NE2 OE1 OE2 CH2 NH1 NH2 OH CZ CZ2 '
    'CZ3 NZ OXT'
).split()

# One-letter codes of the 20 standard amino acids; list position is the
# integer residue type.
restypes = list('ARNDCQEGHILKMFPSTWYV')
restype_order = dict(zip(restypes, range(len(restypes))))
restype_num = len(restypes)  # := 20.

# One-letter -> three-letter residue names, in the same order as `restypes`.
restype_1to3 = dict(zip(
    restypes,
    ('ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE '
     'LEU LYS MET PHE PRO SER THR TRP TYR VAL').split(),
))

def make_np_example(coords_dict):
    """Make a dictionary of non-batched numpy protein features.

    Only the backbone atoms (N, CA, C, O) are read from `coords_dict`; every
    other atom slot stays zero and masked out.

    Args:
        coords_dict: mapping with at least the keys 'N', 'CA', 'C' and 'O',
            each holding per-residue xyz coordinates (length num_res).

    Returns:
        dict with
          'atom_positions': float array [num_res, len(atom_types), 3],
          'atom_mask':      float array [num_res, len(atom_types)], 1 where a
                            backbone atom has finite coordinates,
          'residue_index':  np.arange(num_res).
    """
    bb_atom_types = ['N', 'CA', 'C', 'O']
    bb_idx = [i for i, atom_type in enumerate(atom_types)
              if atom_type in bb_atom_types]

    num_res = np.array(coords_dict['N']).shape[0]
    num_atom_types = len(atom_types)
    atom_positions = np.zeros([num_res, num_atom_types, 3], dtype=float)

    for i in bb_idx:
        atom_positions[:, i, :] = np.array(coords_dict[atom_types[i]])

    # Mask nan / None coordinates. An atom is treated as missing when ANY of
    # its three coordinates is NaN (checking x alone would let a NaN in y or z
    # leak into the centred positions and the loss).
    nan_pos = np.isnan(atom_positions).any(axis=-1)
    atom_positions[nan_pos] = 0.
    atom_mask = np.zeros([num_res, num_atom_types])
    atom_mask[..., bb_idx] = 1
    atom_mask[nan_pos] = 0

    batch = {
        'atom_positions': atom_positions,
        'atom_mask': atom_mask,
        'residue_index': np.arange(num_res)
    }
    return batch


def make_fixed_size(np_example, max_seq_length=500):
    """Pad or crop every feature along axis 0 to `max_seq_length`, in place.

    Shorter arrays are zero-padded at the end of axis 0, longer arrays are
    truncated. The dictionary is modified in place and nothing is returned.
    """
    for key in list(np_example):
        arr = np_example[key]
        deficit = max_seq_length - arr.shape[0]
        if deficit < 0:
            # Too long: keep the leading max_seq_length entries.
            arr = arr[:max_seq_length]
        elif deficit > 0:
            # Too short: append zeros on axis 0 only.
            widths = [(0, deficit)] + [(0, 0)] * (arr.ndim - 1)
            arr = np.pad(arr, widths)
        np_example[key] = arr


def center_positions(np_example):
  """Translate 'atom_positions' so the masked CA centroid sits at the origin.

  Atom slot 1 is used as the CA atom. Positions of masked-out atoms are reset
  to zero after the shift. The dictionary is updated in place.
  """
  positions = np_example['atom_positions']
  mask = np_example['atom_mask']
  ca_xyz, ca_present = positions[:, 1, :], mask[:, 1]

  # Mask-weighted mean of the CA coordinates; the epsilon guards against an
  # example without any CA atom.
  weighted_sum = np.sum(ca_present[..., None] * ca_xyz, axis=0)
  ca_center = weighted_sum / (np.sum(ca_present, axis=0) + 1e-9)

  shifted = positions - ca_center[None, ...]
  np_example['atom_positions'] = shifted * mask[..., None]


class DatasetFromDataframe(torch.utils.data.Dataset):
    """Dataset that serves backbone features built from a DataFrame's 'coords' column."""

    def __init__(self, data_frame, max_seq_length=512):
        # Fixed length every example is padded / cropped to.
        self.max_seq_length = max_seq_length
        self.data = data_frame

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        np_example = make_np_example(row.coords)
        # Both helpers modify np_example in place.
        make_fixed_size(np_example, self.max_seq_length)
        center_positions(np_example)

        example = {}
        for key, value in np_example.items():
            example[key] = torch.tensor(value, dtype=torch.float32)
        return example


# Every chain ID the PDB format supports, in assignment order:
# upper-case letters, then lower-case letters, then digits.
PDB_CHAIN_IDS = (
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    'abcdefghijklmnopqrstuvwxyz'
    '0123456789'
)
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)  # := 62.


@dataclasses.dataclass(frozen=True)
class Protein:
  """Protein structure representation."""

  # Cartesian coordinates of atoms in angstroms. The atom axis follows the
  # module-level `atom_types` list, i.e. the first three slots are N, CA, C.
  atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

  # Amino-acid type for each residue represented as an integer between 0 and
  # 20, where 20 is 'X'.
  aatype: np.ndarray  # [num_res]

  # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
  # is present and 0.0 if not. This should be used for loss masking.
  atom_mask: np.ndarray  # [num_res, num_atom_type]

  # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
  residue_index: np.ndarray  # [num_res]

  # 0-indexed number corresponding to the chain in the protein that this residue
  # belongs to.
  chain_index: np.ndarray  # [num_res]

  # B-factors, or temperature factors, of each residue (in sq. angstroms units),
  # representing the displacement of the residue from its ground truth mean
  # value.
  b_factors: np.ndarray  # [num_res, num_atom_type]


def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
  chain_end = 'TER'
  return (f'{chain_end:<6}{atom_index:>5}      {end_resname:>3} '
          f'{chain_name:>1}{residue_index:>4}')


def to_pdb(prot: Protein) -> str:
  """Converts a `Protein` instance to a PDB string.

  The output contains a single MODEL with one ATOM record per atom whose mask
  value is at least 0.5, a TER record after every chain, and ENDMDL / END
  trailers. Every line is padded to 80 characters.

  Args:
    prot: The protein to convert to PDB.

  Returns:
    PDB string.

  Raises:
    ValueError: if any `aatype` entry is larger than `restype_num`, or if a
      chain index is not below `PDB_MAX_CHAINS`.
  """
  # Local residue alphabet: the 20 standard one-letter codes plus 'X' at
  # index 20. It shadows the module-level `restypes` inside this function.
  restypes = [
    'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
    'S', 'T', 'W', 'Y', 'V', 'X']
  # Integer residue type -> three-letter name; codes missing from
  # `restype_1to3` (such as 'X') become 'UNK'.
  res_1to3 = lambda r: restype_1to3.get(restypes[r], 'UNK')

  pdb_lines = []

  atom_mask = prot.atom_mask
  aatype = prot.aatype
  atom_positions = prot.atom_positions
  residue_index = prot.residue_index.astype(np.int32)
  chain_index = prot.chain_index.astype(np.int32)
  b_factors = prot.b_factors

  if np.any(aatype > restype_num):
    raise ValueError('Invalid aatypes.')

  # Construct a mapping from chain integer indices to chain ID strings.
  chain_ids = {}
  for i in np.unique(chain_index):  # np.unique gives sorted output.
    if i >= PDB_MAX_CHAINS:
      raise ValueError(
          f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
    chain_ids[i] = PDB_CHAIN_IDS[i]

  pdb_lines.append('MODEL     1')
  atom_index = 1
  last_chain_index = chain_index[0]
  # Add all atom sites.
  # The number of residues written is taken from `aatype`, so the other
  # per-residue arrays must be at least that long along axis 0.
  for i in range(aatype.shape[0]):
    # Close the previous chain if in a multichain PDB.
    if last_chain_index != chain_index[i]:
      pdb_lines.append(_chain_end(
          atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]],
          residue_index[i - 1]))
      last_chain_index = chain_index[i]
      atom_index += 1  # Atom index increases at the TER symbol.

    res_name_3 = res_1to3(aatype[i])
    for atom_name, pos, mask, b_factor in zip(
        atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
      # Skip atoms that are masked out.
      if mask < 0.5:
        continue

      record_type = 'ATOM'
      # Names shorter than four characters get one leading space.
      name = atom_name if len(atom_name) == 4 else f' {atom_name}'
      alt_loc = ''
      insertion_code = ''
      occupancy = 1.00
      element = atom_name[0]  # Protein supports only C, N, O, S, this works.
      charge = ''
      # PDB is a columnar format, every space matters here!
      atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
                   f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
                   f'{residue_index[i]:>4}{insertion_code:>1}   '
                   f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
                   f'{occupancy:>6.2f}{b_factor:>6.2f}          '
                   f'{element:>2}{charge:>2}')
      pdb_lines.append(atom_line)
      atom_index += 1

  # Close the final chain.
  pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
                              chain_ids[chain_index[-1]], residue_index[-1]))
  pdb_lines.append('ENDMDL')
  pdb_lines.append('END')

  # Pad all lines to 80 characters.
  pdb_lines = [line.ljust(80) for line in pdb_lines]
  return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.

In [3]:
print('Reading chain_set.jsonl, this can take 1 or 2 minutes...')
# NOTE(review): absolute, machine-specific paths; consider a configurable
# data directory.
df = pd.read_json('/home/jovyan/protein_diffusion/data/chain_set.jsonl', lines=True)
cath_splits = pd.read_json('/home/jovyan/protein_diffusion/data/chain_set_splits.json', lines=True)
print('Read data.')

# Build the chain-name -> split lookup once instead of scanning the split
# lists for every row of `df`.
# Assumes each cath_splits[<split>][0] is a list of chain-name strings --
# TODO confirm against the splits file.
# Splits are inserted in reverse priority so that, if a name appeared in more
# than one split, 'train' wins over 'validation' over 'test', matching the
# original if/elif order.
_split_by_name = {}
for _split_name in ('test', 'validation', 'train'):
  for _chain_name in cath_splits[_split_name][0]:
    _split_by_name[_chain_name] = _split_name

def get_split(pdb_name):
  """Return 'train', 'validation' or 'test' for a chain name.

  Names that belong to no split map to the string 'None' (not the None
  object).
  """
  return _split_by_name.get(pdb_name, 'None')

df['split'] = df['name'].map(get_split)
df['seq_len'] = df['seq'].map(len)

Reading chain_set.jsonl, this can take 1 or 2 minutes...
Read data.


In [4]:
# NOTE(review): this cell re-declares the DatasetFromDataframe class that an
# earlier cell already defines with the same behaviour; from here on this
# definition is the one in effect.
class DatasetFromDataframe(torch.utils.data.Dataset):
    """Dataset that serves backbone features built from a DataFrame's 'coords' column."""

    def __init__(self, data_frame, max_seq_length=512):
        # Fixed length every example is padded / cropped to.
        self.max_seq_length = max_seq_length
        self.data = data_frame

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        np_example = make_np_example(row.coords)
        # Both helpers modify np_example in place.
        make_fixed_size(np_example, self.max_seq_length)
        center_positions(np_example)

        example = {}
        for key, value in np_example.items():
            example[key] = torch.tensor(value, dtype=torch.float32)
        return example

In [5]:
class Config:
  """Settings shared by the dataset and the data loader."""

  def __init__(self):
    # Number of examples per batch.
    self.batch_size = 128
    # Number of residues every example is padded / cropped to.
    self.max_seq_length = 512

cfg = Config()

train_set = DatasetFromDataframe(df[df.split == 'train'],
                                 max_seq_length=cfg.max_seq_length)
# Reshuffle the training examples every epoch. With shuffle=False each epoch
# replays the same batches in DataFrame order, so batch composition never
# varies.
train_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size=cfg.batch_size,
                                           shuffle=True)
train_iter = iter(train_loader)

In [6]:
class UNet(nn.Module):
    """
    A minimal convolutional noise-prediction network.

    Encoder: 3 convolutional layers, each followed by a ReLU activation.

    Decoder: 3 convolutional layers, the first two followed by a ReLU
    activation.

    self.time_mlp: embeds the scalar time signal with a two-layer MLP into a
    `time_emb_dim`-dimensional vector.

    self.batch_norm: batch normalisation applied at the bottleneck, after the
    time embedding has been added to the encoder output.

    The bottleneck width equals `time_emb_dim`, so the time embedding can be
    added channel-wise for any value of `time_emb_dim` (default 256).
    """
    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=256):
        super().__init__()

        # time embedding block
        self.time_mlp = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, time_emb_dim),
            nn.ReLU()
        )

        # simple encoder block; its last layer must emit `time_emb_dim`
        # channels so that the time embedding can be added to it.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, time_emb_dim, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # simple decoder block
        self.decoder = nn.Sequential(
            nn.Conv2d(time_emb_dim, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, out_channels, kernel_size=3, padding=1)
        )
        self.batch_norm = nn.BatchNorm2d(time_emb_dim)


    def forward(self, x, t):
        """
        Predict a tensor with the same layout as `x`.

        Args:
            x: channels-last input, (batch, height, width, in_channels).
            t: time signal whose last dimension is 1, e.g. (batch, 1); it is
               cast to float before being embedded.

        Returns:
            Tensor of shape (batch, height, width, out_channels).
        """
        # permute to channels first
        x = x.permute(0, 3, 1, 2)
        # (batch, time_emb_dim) -> (batch, time_emb_dim, 1, 1) so it broadcasts
        # over the two spatial axes.
        t_emb = self.time_mlp(t.float()).unsqueeze(-1).unsqueeze(-1)

        # the time embedding is injected at the bottleneck
        x = self.batch_norm(self.encoder(x) + t_emb)
        x = self.decoder(x)

        # permute back to channels last
        x = x.permute(0, 2, 3, 1)
        return x

In [7]:
class DDPM(nn.Module):
    """
    DDPM paper by Ho. et al: https://arxiv.org/abs/2006.11239.
    Complete implementation by Umar Jamil: https://github.com/hkproj/pytorch-ddpm/.

    A minimal implementation of the DDPM framework around a noise-prediction
    network `unet`, called as `unet(x, t)`.

    self.timesteps: maximum noising / denoising timesteps.

    self.betas: linear noise schedule, the variance of the Gaussian noise
    added at each step.

    self.alphas: 1 - betas.

    self.alphas_cumprod: cumulative products of the alphas, which allow the
    noisy input at an arbitrary timestep t to be obtained in one step.

    The schedule tensors are plain attributes, so they are not part of the
    state dict and are not moved by `.to(device)`. The methods index them
    where they live and move only the selected coefficients to the data's
    device.
    """
    def __init__(self, unet, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        super(DDPM, self).__init__()

        self.unet = unet
        self.timesteps = timesteps

        self.betas = torch.linspace(beta_start, beta_end, timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    @staticmethod
    def _schedule_index(t):
        """Return `t` in a form that can index the CPU schedule tensors."""
        if isinstance(t, torch.Tensor):
            return t.detach().cpu()
        return t

    def forward_diffusion(self, x0, t, noise):
        """
        Applies noise to an input x0 to reach xt, where t is an arbitrary
        timestep:  xt = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise.

        Args:
            x0: clean batch, (batch, num_res, num_atoms, 3).
            t: integer timesteps, one per batch element.
            noise: Gaussian noise with the same shape as x0.

        The result lives on x0's device (no global `device` is required).
        """
        alpha_bar = self.alphas_cumprod[self._schedule_index(t)]

        sqrt_alpha_cumprod = torch.sqrt(
            alpha_bar
        ).view(-1, 1, 1, 1).to(x0.device)

        sqrt_one_minus_alpha_cumprod = torch.sqrt(
            1 - alpha_bar
        ).view(-1, 1, 1, 1).to(x0.device)

        noisy_x = sqrt_alpha_cumprod * x0 + sqrt_one_minus_alpha_cumprod * noise
        return noisy_x

    def denoise_at_t(self, noisy_x, predicted_noise, t):
        """
        One-shot estimate of the clean structure x0 from a noisy input at
        timestep t, obtained by inverting the forward process (Ho et al.,
        Eq. 15):

            x0 ~= (xt - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)

        This is NOT the single reverse step xt -> x(t-1); `sample` uses
        `_reverse_step` for that.

        Args:
            noisy_x: noisy structure xt.
            predicted_noise: noise estimate with the same shape as noisy_x.
            t: timestep, a Python int or a 0-dim integer tensor.
        """
        alpha_bar = self.alphas_cumprod[self._schedule_index(t)]
        sqrt_alpha_cumprod = torch.sqrt(alpha_bar)
        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alpha_bar)
        if isinstance(noisy_x, torch.Tensor):
            sqrt_alpha_cumprod = sqrt_alpha_cumprod.to(noisy_x.device)
            sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alpha_cumprod.to(
                noisy_x.device)

        denoised_struct = (noisy_x - sqrt_one_minus_alpha_cumprod*predicted_noise) / sqrt_alpha_cumprod
        return denoised_struct

    def _reverse_step(self, x, predicted_noise, t):
        """
        One ancestral sampling step xt -> x(t-1) (Ho et al., Algorithm 2):

            x(t-1) = (xt - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
                     + sqrt(beta_t) * z,     z ~ N(0, I), and z = 0 at t == 0.

        `t` is a Python int.
        """
        beta_t = self.betas[t].item()
        alpha_t = self.alphas[t].item()
        alpha_bar_t = self.alphas_cumprod[t].item()

        mean = (x - beta_t / (1.0 - alpha_bar_t) ** 0.5 * predicted_noise) / alpha_t ** 0.5
        if t == 0:
            # No noise is added on the final step.
            return mean
        return mean + beta_t ** 0.5 * torch.randn_like(x)

    def get_np_structure(self, sample):
        """
        Return the np_structure to build an instance of the Protein class.

        Args:
            sample: Tensor or numpy array of shape (num_res, num_atoms, 3).

        Returns:
            dict with 'atom_positions' (NaN atoms zeroed), 'residue_index'
            and 'atom_mask' (1 for the backbone atoms N, CA, C, O).
        The input is copied, never modified.
        """
        # to make this method compatible with tensor
        # and numpy inputs for general usability
        if isinstance(sample, torch.Tensor):
            sample = sample.detach().cpu().numpy()
        # Work on a copy: the array may share memory with the caller's data.
        sample = np.array(sample)
        bb_atom_types = ['N', 'CA', 'C', 'O']
        bb_idx = [i for i, atom_type in enumerate(atom_types)
                  if atom_type in bb_atom_types]

        num_res = len(sample)
        # An atom is zeroed when any of its coordinates is NaN.
        nan_pos = np.isnan(sample).any(axis=-1)
        sample[nan_pos] = 0.
        atom_mask = np.zeros([num_res, len(atom_types)])
        atom_mask[..., bb_idx] = 1

        np_sample = {
            'atom_positions': sample,
            'residue_index': np.arange(num_res),
            'atom_mask': atom_mask,
        }
        return np_sample

    def get_protein(self, x):
        """
        Returns a Protein instance from a given input.
        Expects Tensor or Numpy ndarray of shape (num_res, num_atoms, 3).

        The residue type, chain index and B-factor arrays are placeholders
        (all residue type 0, a single chain, B-factor 1). They are sized from
        the input itself so structures of any length can be converted.
        """
        np_sample = self.get_np_structure(x)
        num_res, num_atoms = np_sample['atom_mask'].shape
        prot = Protein(
            atom_positions=np_sample['atom_positions'],
            atom_mask=np_sample['atom_mask'],
            residue_index=np_sample['residue_index'],
            aatype=np.zeros([num_res,], dtype=np.int32),
            chain_index=np.zeros([num_res,], dtype=np.int32),
            b_factors=np.ones([num_res, num_atoms], dtype=np.int32)
        )
        return prot

    def get_predicted_protein(self, noisy_x, predicted_noise, t):
        """
        Given a noisy input (noisy_x), predicted_noise and timestep (t),
        removes the predicted noise from the noisy input in one shot and
        returns a Protein instance of the estimated clean structure.
        """
        predicted_struct = self.denoise_at_t(noisy_x, predicted_noise, t)
        prot = self.get_protein(predicted_struct)
        return prot


    def sample(self, shape, device):
        """
        Ancestral sampling: starts from pure Gaussian noise and applies one
        reverse step per timestep, from t = timesteps - 1 down to 0.
        Fresh noise is added at every step except the last.

        Args:
            shape: shape of the batch to generate, e.g. (1, num_res, 37, 3).
            device: device on which to sample.

        Returns:
            (x, denoised_over_T, predicted_noise_over_T): the final sample,
            the list of intermediate states (initial noise first,
            timesteps + 1 entries) and the list of predicted noises
            (timesteps entries).
        """
        predicted_noise_over_T = []
        with torch.no_grad():
            x = torch.randn(shape, device=device)
            denoised_over_T = [x]
            for t in reversed(range(self.timesteps)):
                # (1, 1) time input, broadcast over the batch by the network.
                tmp_t = torch.tensor(t, device=device).view(-1, 1)
                predicted_noise = self.unet(x, tmp_t.float())
                predicted_noise_over_T.append(predicted_noise)
                x = self._reverse_step(x, predicted_noise, t)
                denoised_over_T.append(x)
        return x, denoised_over_T, predicted_noise_over_T

In [8]:
def masked_mse_loss(pred_noise, actual_noise, mask):
    """
    Masked Mean Squared Error loss.

    Computes the squared error between the predicted and the actual noise and
    averages it over the positions where `mask` is non-zero, i.e. the valid
    positions (not null coordinates, not padding).

    Args:
        pred_noise: predicted noise.
        actual_noise: target noise, same shape as pred_noise.
        mask: tensor with the same shape (bool or 0/1 valued).

    Returns:
        Scalar tensor. If the mask selects nothing the loss is 0 rather than
        NaN (0 / 0), so a fully masked batch cannot poison the optimizer.
    """
    mask = mask.float()
    squared_error = F.mse_loss(pred_noise, actual_noise, reduction='none')
    # The tiny floor only matters when the mask is empty.
    valid_count = mask.sum().clamp(min=1e-8)
    return (squared_error * mask).sum() / valid_count

def create_scheduler(config):
    """Build a OneCycleLR learning-rate scheduler from a config mapping.

    Args:
        config (dict): must provide the keys "optimizer", "max_lr", "epochs"
            and "steps_per_epoch"; they are forwarded unchanged to
            OneCycleLR.

    Returns:
        OneCycleLR: scheduler attached to config["optimizer"].
    """
    forwarded_keys = ("optimizer", "max_lr", "epochs", "steps_per_epoch")
    scheduler_kwargs = {key: config[key] for key in forwarded_keys}
    return OneCycleLR(**scheduler_kwargs)

In [None]:
# used for mig devices only
os.environ['CUDA_VISIBLE_DEVICES'] = "0, 1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The loop below calls DDPM-specific members (timesteps, forward_diffusion,
# unet), so the bare model is used on every machine. The original wrapped the
# model in DataParallel on multi-GPU hosts and unwrapped it again straight
# away, and accessed `.module` unconditionally, which raises AttributeError
# with zero or one GPU.
model = DDPM(UNet()).to(device)

lr=3e-3
epochs=50
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = masked_mse_loss

scheduler_config = {
    "optimizer": optimizer,
    "max_lr": lr,
    "epochs": epochs,
    "steps_per_epoch": len(train_loader)
}

scheduler = create_scheduler(scheduler_config)

example_proteins = {
    'original': [],
    'noisy': [],
    'predicted_noise': [],
    'timestep': []
}
metrics = {
    'step_wise_loss': [],
    'epoch_wise_loss': [],
}

# Snapshot roughly every 20% of an epoch; at least 1 so that loaders with
# fewer than 5 batches do not trigger a modulo-by-zero.
snapshot_every = max(1, len(train_loader) // 5)

model.train()
for epoch in range(epochs):
    total_loss = 0
    iterator = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=True)
    for idx, batch in enumerate(iterator):
        x = batch['atom_positions'].to(device)

        # broadcast the per-atom mask over the coordinate axis so that it
        # matches the shape expected by the loss function
        mask = batch['atom_mask'].to(device).unsqueeze(-1).expand_as(x)

        # sample a batch of random time steps
        t = torch.randint(0, model.timesteps, [x.shape[0]], device=device, dtype=torch.long)

        # sample batch Gaussian noise
        noise = torch.randn(x.shape, device=device)

        # apply forward process
        noisy_x = model.forward_diffusion(x, t, noise)

        optimizer.zero_grad()

        # estimates the added noise at time step t using
        # the noise estimation network
        predicted_noise = model.unet(noisy_x, t.unsqueeze(-1))

        loss = criterion(predicted_noise, noise, mask)

        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_value = loss.item()
        total_loss += loss_value

        iterator.set_postfix(loss=f"{loss_value:.4f}")
        metrics['step_wise_loss'].append(loss_value)

        if idx % snapshot_every == 0:
            # saves the original and noisy backbones of the first example in
            # the batch, together with the predicted noise and the timestep,
            # to later investigate and evaluate model training.
            # Snapshots are detached and moved to the CPU so they neither
            # keep the autograd graph alive nor accumulate GPU memory.
            example_idx = 0
            example_proteins['original'].append(x[example_idx].detach().cpu())
            example_proteins['noisy'].append(noisy_x[example_idx].detach().cpu())
            example_proteins['predicted_noise'].append(predicted_noise[example_idx].detach().cpu())
            example_proteins['timestep'].append(t[example_idx].detach().cpu())

    metrics['epoch_wise_loss'].append(total_loss/len(train_loader))
    print(f"Epoch {epoch+1}/{epochs}, Loss: {metrics['epoch_wise_loss'][-1]:.4f}")

# prepend the very first step loss so the per-epoch curve starts from the
# untrained model
metrics['epoch_wise_loss'] = [metrics['step_wise_loss'][0]] + metrics['epoch_wise_loss']

Epoch 1/50: 100%|██████████| 141/141 [03:20<00:00,  1.43s/it, loss=0.3244]


Epoch 1/50, Loss: 0.5328


Epoch 2/50: 100%|██████████| 141/141 [03:18<00:00,  1.41s/it, loss=0.2939]


Epoch 2/50, Loss: 0.3073


Epoch 3/50: 100%|██████████| 141/141 [03:15<00:00,  1.38s/it, loss=0.2432]


Epoch 3/50, Loss: 0.2702


Epoch 4/50: 100%|██████████| 141/141 [03:21<00:00,  1.43s/it, loss=0.1949]


Epoch 4/50, Loss: 0.2620


Epoch 5/50: 100%|██████████| 141/141 [03:30<00:00,  1.50s/it, loss=0.2611]


Epoch 5/50, Loss: 0.2477


Epoch 6/50: 100%|██████████| 141/141 [03:34<00:00,  1.52s/it, loss=0.1930]


Epoch 6/50, Loss: 0.2372


Epoch 7/50: 100%|██████████| 141/141 [03:34<00:00,  1.52s/it, loss=0.2135]


Epoch 7/50, Loss: 0.2350


Epoch 8/50: 100%|██████████| 141/141 [03:31<00:00,  1.50s/it, loss=0.2511]


Epoch 8/50, Loss: 0.2240


Epoch 9/50: 100%|██████████| 141/141 [03:34<00:00,  1.52s/it, loss=0.2006]


Epoch 9/50, Loss: 0.2250


Epoch 10/50:  61%|██████    | 86/141 [02:08<01:19,  1.45s/it, loss=0.2249]

In [None]:
# plotting loss curves: per-step loss on the left, per-epoch loss on the right
loss_panels = [
    (metrics['step_wise_loss'], "Per Step Loss", 'blue', "Training Steps"),
    (metrics['epoch_wise_loss'], "Per Epoch Loss", 'red', "Epochs"),
]

plt.figure(figsize=(12, 5))
for panel_no, (values, title, color, x_label) in enumerate(loss_panels, start=1):
    plt.subplot(1, 2, panel_no)
    plt.plot(values, label=title, color=color)
    plt.xlabel(x_label)
    plt.ylabel("Loss")
    plt.title(title)
    plt.legend()

plt.show()

In [None]:
# saves the model weights and the example
# proteins saved during training
# to avoid having to re-run the training
# process in case the runtime crashes
# during investigation of the results

# `with` closes the file even if pickling fails part-way.
with open("example_proteins.pkl", "wb") as f:
    pickle.dump(example_proteins, f)

dict_to_save = {
    'model_state_dict': model.state_dict(),
}
torch.save(dict_to_save, 'model_state_dict.pt')

In [None]:
# Take one snapshot recorded during training and remove its predicted noise
# in a single step, using the timestep stored with it, to obtain the
# generated backbone.
idx = -7
denoised = model.denoise_at_t(*(
    example_proteins[key][idx].detach().cpu()
    for key in ('noisy', 'predicted_noise', 'timestep')
))
sample = denoised.numpy()

In [None]:
# Build a Protein from `sample`, convert it to PDB text and render it.

# To plot the original backbone instead, uncomment:
# sample = example_proteins['original'][idx].detach().cpu().numpy()

# To plot the noisy backbone instead, uncomment:
# sample = example_proteins['noisy'][idx].detach().cpu().numpy()


prot = model.get_protein(sample)
pdb_str = to_pdb(prot)

# Render in a single-panel viewer.
style = dict(cartoon=dict(color='spectrum'))
view = py3Dmol.view(
    viewergrid=(1, 1), linked=True, width=600, height=600)
view.setViewStyle(dict(style='outline', color='black', width=0.1))

view.addModelsAsFrames(pdb_str, viewer=(0, 0))
view.setStyle(dict(model=-1), style, viewer=(0, 0))
view.zoomTo(viewer=(0, 0))

view.render()

In [None]:
# Perform sampling (the sampler itself is still flagged as needing a fix) and
# render the result.
model.eval()
# Sample cfg.max_seq_length residues: get_protein builds its residue-type,
# chain and B-factor arrays with that length, and to_pdb indexes the
# coordinates up to it, so a shorter sample (previously 256) raises an
# IndexError.
sample, samples_over_T, noise_over_T = model.sample(
    (1, cfg.max_seq_length, 37, 3), device)
sample = sample[0].detach().cpu().numpy()

prot = model.get_protein(sample)

pdb_str = to_pdb(prot)

# Render.
view = py3Dmol.view(
    width=600, height=600, linked=True , viewergrid=(1, 1))
view.setViewStyle({'style': 'outline', 'color': 'black', 'width': 0.1})
style = {"cartoon": {'color': 'spectrum'}}

view.addModelsAsFrames(pdb_str, viewer=(0, 0))
view.setStyle({'model': -1}, style, viewer=(0, 0))
view.zoomTo(viewer=(0, 0))

view.render()