<a href="https://colab.research.google.com/github/radhikasethi2011/ProteinClassify/blob/main/colab/creating_pdb_class_tfrecords_(220).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/radhikasethi2011/ProteinClassify.git

In [None]:
import numpy as np
import pandas as pd
import os
from statistics import mode
import glob as glob

In [None]:
# Load the SCOP classification table (space separated, 7 header lines skipped).
# Only three columns survive the drop; they are renamed residue / chain / label.
cols = [0,3,4,5,6,7,8,9]
scop_cla = pd.read_csv('/content/ProteinClassify/scop-cla-latest.txt',
                       header=None, skiprows=7, delimiter=' ')
scop_cla = scop_cla.drop(columns=scop_cla.columns[cols])
scop_cla = scop_cla.rename(
    columns=dict(zip(scop_cla.columns[:3], ["residue", "chain", "label"])))
scop_cla = scop_cla.assign(
    # keep the text before ':' of the region field
    chain=scop_cla['chain'].str.split(':').str[0],
    # value of the second comma-separated "KEY=value" item
    # (presumably the SCOP class id -- confirm against the file header)
    label=scop_cla['label'].str.split(',').str[1].str.split('=').str[1],
).assign(
    # build "<id>_<chain>" keys, matching scop_struct['pdbid'] below
    residue=lambda frame: frame['residue'] + '_' + frame['chain'],
)

scop_cla

In [None]:
# Load the SCOP represented-structures table and key each row as "<pdbid>_<chain>".
cols = ['domid', 'pdbid', 'pdbchain']
scop_struct = pd.read_csv('/content/ProteinClassify/scop-represented-structures-latest.txt',
                          header=None, skiprows=6, names=cols, delimiter=' ')
scop_struct = scop_struct.assign(
    pdbid=scop_struct['pdbid'] + '_' + scop_struct['pdbchain'])
scop_struct

In [None]:
# Spot-check: the label(s) recorded for chain E of 5FLV.
mask = scop_cla['residue'] == '5FLV_E'
scop_cla.loc[mask, 'label']

In [None]:
# How many "<pdbid>_<chain>" keys appear in both tables.
ans = set(scop_cla['residue']) & set(scop_struct['pdbid'])
len(ans)

In [None]:
# Attach a SCOP label to every represented chain, keep the first match per
# "<pdbid>_<chain>" key, then reduce the key to the bare PDB id.
merged = (
    scop_struct
    .merge(scop_cla, left_on='pdbid', right_on='residue')
    .drop_duplicates(subset=['pdbid'], keep='first')
    .drop(columns=['pdbid', 'pdbchain'])
    .assign(residue=lambda frame: frame['residue'].str.split('_').str[0])
    .reset_index(drop=True)
)
merged.to_csv('pdb_chain_class.csv', sep=',')
# Index by PDB id; an id with several chains appears on several rows.
merged = merged.set_index('residue')

merged

In [None]:
# Spot-check: labels of every chain kept for PDB entry 5FLV.
mask = merged.index == '5FLV'
merged.loc[mask, 'label']

In [None]:
merged.loc[mask, 'chain']

In [None]:
# Output folders: downloaded PDB files, per-entry CA CSVs, and TFRecords.
# exist_ok=True makes the cell safe to re-run (plain mkdir errors if the folder exists).
for folder in ('/content/pdbs', '/content/ca_csv', '/content/records'):
  os.makedirs(folder, exist_ok=True)

In [None]:
def is_atom_record(record):
  """Return True when the PDB line begins with the ATOM record name."""
  return record.startswith('ATOM')

def is_intended_chain(record, chain):
  """Return True when column 22 (chain identifier, index 21) equals `chain`."""
  return bool(record[21] == chain)

def is_chain_ter_record(record, chain):
  """Return True when `record` is a TER line whose chain identifier is `chain`.

  Args:
    record: one line of a PDB file.
    chain: single-character chain identifier.

  Returns:
    bool. Lines shorter than 22 characters (e.g. a bare "TER" or "END") are
    never a match; the original indexed record[21] first and raised IndexError.
  """
  # Check length and record name before touching column 22 (index 21).
  return len(record) > 21 and record.startswith("TER") and record[21] == chain

def is_alt_record(record):
  """Return True when the alternate-location column (index 16) is not blank."""
  return record[16] != " "
  
def is_ca_atom(record):
  """Return True when the atom-name field of a PDB ATOM line is exactly "CA".

  The atom name occupies columns 13-16 (indices 12:16), the same slice that
  parse_atom_records reads. The original looked at only 12:15, so a 4-character
  name starting with " CA" (e.g. " CA1") was wrongly treated as a C-alpha.
  """
  return record[12:16].strip() == "CA"

def parse_atom_records(record):
  """Split a fixed-column PDB ATOM line into stripped string fields.

  Returns:
    (atom name, residue name, chain id, residue sequence number, x, y, z),
    all as str. Nothing is converted to numbers here.
  """
  def column(start, stop):
    return record[start:stop].strip()

  return (column(12, 16),       # atom name
          column(17, 20),       # residue name
          record[21].strip(),   # chain id
          column(22, 26),       # residue sequence number
          column(30, 38),       # x
          column(38, 46),       # y
          column(46, 54))       # z

def parse_pdb(contents, chainl):
  """Collect the C-alpha ATOM records of the requested chains from a PDB file.

  Args:
    contents: list of lines of a PDB-format file (e.g. from f.readlines()).
    chainl: list of chain identifiers to extract. Records are emitted chain
      by chain, in this order (a repeated id is emitted again, as before).

  Returns:
    dict mapping a running 0-based index to the tuple returned by
    parse_atom_records: (atom, residue, chain, seq_pos, x, y, z).

  Notes:
    - Reading stops at the first ENDMDL, so only the first model is used.
    - Each chain stops at its own TER record.
    - A CA with alternate locations keeps only the first alternate location
      seen for that residue. The original skipped every altloc record, which
      silently removed such residues from the trace.
    - Single pass over `contents`. The original rescanned the whole file once
      per chain.
  """
  wanted = set(chainl)
  per_chain = {chain_id: [] for chain_id in wanted}
  finished = set()             # chains whose TER record has already been seen
  seen_alt_residues = set()    # (chain, resSeq + iCode) already taken from an altloc line

  for line in contents:
    if line.startswith("ENDMDL"):
      break
    # Slice, not index, column 22 so short lines such as a bare "TER" are safe.
    if line.startswith("TER") and line[21:22] in wanted:
      finished.add(line[21:22])
      continue
    if not is_atom_record(line) or not is_ca_atom(line):
      continue
    chain_id = line[21]
    if chain_id not in wanted or chain_id in finished:
      continue
    if is_alt_record(line):
      residue_key = (chain_id, line[22:27])
      if residue_key in seen_alt_residues:
        continue
      seen_alt_residues.add(residue_key)
    per_chain[chain_id].append(parse_atom_records(line))

  ca_records = {}
  for chain_id in chainl:
    for record in per_chain[chain_id]:
      ca_records[len(ca_records)] = record
  return ca_records

def dump_records_to_csv_file(file_name, record_dict):
  """Write CA records to `file_name` as headerless comma-separated lines.

  Args:
    file_name: output path (overwritten).
    record_dict: dict whose values are 7-tuples
      (atom, residue, chain, pos, x, y, z). Rows are written in the dict's
      iteration order, each field formatted with %s.
  """
  # `with` guarantees the data is flushed and the handle closed before a caller
  # reads the file back. The original never closed it and relied on garbage collection.
  with open(file_name, 'w') as out:
    for atom, residue, chain, pos, x, y, z in record_dict.values():
      out.write("%s,%s,%s,%s,%s,%s,%s\n" % (atom, residue, chain, pos, x, y, z))


In [None]:
import tensorflow as tf 

# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Wrap one bytes value (or an eager tensor holding one) in a BytesList Feature."""
  eager_tensor_type = type(tf.constant(0))
  if isinstance(value, eager_tensor_type):
    # BytesList cannot unpack an EagerTensor, so take its numpy value first.
    value = value.numpy()
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
  """Wrap one float / double value in a FloatList Feature."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wrap one bool / enum / int / uint value in an Int64List Feature."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)


In [None]:
def serialize_example(feature0, feature1, feature3,
                      feature4, feature5, feature6):
  """Serialize one row into a tf.train.Example byte string.

  Args:
    feature0: bytes stored under 'label'.
    feature1: bytes stored under 'residue'.
    feature3: integer stored under 'pos'.
    feature4, feature5, feature6: floats stored under 'x', 'y', 'z'.

  Returns:
    The serialized Example (bytes), ready for a TFRecordWriter.
  """
  features = tf.train.Features(feature={
      'label': _bytes_feature(feature0),
      'residue': _bytes_feature(feature1),
      'pos': _int64_feature(feature3),
      'x': _float_feature(feature4),
      'y': _float_feature(feature5),
      'z': _float_feature(feature6),
  })
  return tf.train.Example(features=features).SerializeToString()




In [None]:
def write_to_testrecord(filename, df, label):
  """Write one serialized tf.train.Example per row of `df` to a TFRecord file.

  Args:
    filename: output .tfrecord path (overwritten).
    df: frame with columns 'atom', 'pos', 'x', 'y', 'z', as returned by
      edit_pos. With edit_pos's column names, 'atom' is the second CSV field,
      i.e. the residue name written by dump_records_to_csv_file. It is stored
      under the 'residue' feature.
    label: label string stored in every example.
  """
  label_bytes = bytes(label, 'utf-8')   # identical for every row, so encode once
  columns = (df['atom'], df['pos'], df['x'], df['y'], df['z'])
  with tf.io.TFRecordWriter(filename) as writer:
    # Iterate column values positionally. The original looked rows up by
    # index label and also re-parsed every serialized example into an unused variable.
    for residue_name, pos, x, y, z in zip(*columns):
      writer.write(serialize_example(label_bytes, bytes(residue_name, 'utf-8'),
                                     pos, x, y, z))

In [None]:
def edit_pos(csv_name, csv_dir='/content/ca_csv'):
  """Load a per-entry CA CSV and translate coordinates so the first row is the origin.

  Args:
    csv_name: file stem; reads <csv_dir>/<csv_name>.csv.
    csv_dir: directory holding the CSVs. The default is the original
      hard-coded Colab location.

  Returns:
    DataFrame with columns ['ca', 'atom', 'residue', 'pos', 'x', 'y', 'z'],
    where x/y/z are offsets from the first row's coordinates.
    NOTE: these names are shifted relative to the field order written by
    dump_records_to_csv_file (atom, residue, chain, pos, x, y, z). 'ca' holds
    the atom name, 'atom' the residue name and 'residue' the chain id. They are
    kept unchanged because write_to_testrecord reads df['atom'].

  Raises:
    ValueError: if the CSV is empty (no CA records were extracted).
  """
  col_names = ['ca', 'atom', 'residue', 'pos', 'x', 'y', 'z']
  path = os.path.join(csv_dir, f'{csv_name}.csv')
  # pd.read_csv raises a bare EmptyDataError on an empty file; say why instead.
  if os.path.getsize(path) == 0:
    raise ValueError(f'{path} is empty: no CA records to translate')
  df = pd.read_csv(path, names=col_names, sep=',')
  for axis in ('x', 'y', 'z'):
    # Subtract the first atom's coordinate so the trace starts at (0, 0, 0).
    df[axis] = df[axis] - df[axis].iloc[0]
  return df




In [None]:
pd_list = list(merged.index)
for i in range(len(pd_list)):
  pd_id = pd_list[i]
  #print(i, pd_id)
  !wget -q https://files.rcsb.org/download/{pd_id}.pdb -O /content/pdbs/{pd_id}.pdb
  size = os.path.getsize(f"/content/pdbs/{pd_id}.pdb")
  if size>0:
    chainl = list(merged['chain'][pd_id])
    print("i: ", i, pd_id, chainl)
    with open(f'/content/pdbs/{pd_id}.pdb') as f:
      contents = f.readlines() 
    ca_records = parse_pdb(contents, chainl) 
    dump_records_to_csv_file(f'/content/ca_csv/{pd_id}.csv', ca_records)
    df = edit_pos(pd_id)
    label = merged['label'][pd_id]
    if(type(label) is str):
      filename = '/content/records/' + f'{pd_id}' + '_' + f'{label}' + '.tfrecord'
      write_to_testrecord(filename, df, label)
    else:
      for l in label: 
        filename = '/content/records/' + f'{pd_id}' + '_' + f'{l}' + '.tfrecord'
        write_to_testrecord(filename, df, l)
  else: print("! ! ! ! file not found at rcsb ! ! ! !")


  


In [None]:
# Spot-check: decode the first 10 examples of one written TFRecord file.
# Find the file on disk rather than relying on the `filename` variable leaked
# from the previous cell's loop. That variable is undefined if no record was written.
filenames = sorted(glob.glob('/content/records/*.tfrecord'))[-1:]
assert filenames, "no .tfrecord files found in /content/records"
raw_dataset = tf.data.TFRecordDataset(filenames)

for raw_record in raw_dataset.take(10):
  example = tf.train.Example()
  example.ParseFromString(raw_record.numpy())
  print(example)


In [None]:
!tar -zcvf records.tar.gz /content/records 