In [227]:
import os
import gzip
import pickle
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from io import StringIO

In [231]:
class ReadGzip():
  """Thin wrapper around a gzip-compressed text file that hands out lines in chunks.

  Typical use: `open_gzip()` once, `load_next_chunk(n)` repeatedly, then
  `close_gzip()`. Calling `open_gzip()` again restarts reading from the top.
  """
  def __init__(self, file_path):
    self.file_path = file_path
    self.gzip_file = None  # text-mode handle; None while nothing is open
    self.chunk_string = str()  # text of the most recently loaded chunk

  def open_gzip(self, skip_header=True):
    """Open (or reopen) the file for reading from its first line.

    Any handle left open by a previous, unfinished pass is closed first so it
    is not leaked. If `skip_header` is true, the first line is consumed and
    discarded; an empty file is tolerated.
    """
    self.close_gzip()
    self.gzip_file = gzip.open(self.file_path, 'rt')
    if skip_header:
      # readline() returns '' at EOF instead of raising StopIteration like next().
      self.gzip_file.readline()

  def close_gzip(self):
    """Close the underlying handle if one is open. Safe to call repeatedly."""
    if self.gzip_file is not None:
      self.gzip_file.close()
      self.gzip_file = None

  def load_next_chunk(self, lines_per_chunk):
    """Read the next `lines_per_chunk` lines and return them as one string.

    Once the end of the file is reached, fewer lines (possibly '') are returned.
    The result is also stored in `self.chunk_string`.

    Raises:
      ValueError: if the file has not been opened with `open_gzip()`.
    """
    if self.gzip_file is None:
      raise ValueError('gzip file is not open; call open_gzip() first')
    # join avoids the repeated-copy cost of building the string with +=.
    self.chunk_string = ''.join(self.gzip_file.readline() for _ in range(lines_per_chunk))
    return self.chunk_string


In [276]:
class GzChunckChunker(Dataset):
  """Map-style Dataset that streams rows of a gzipped TSV file chunk by chunk.

  Each line is parsed with `pd.read_csv(sep='\\t', header=None)`. Column 0 is
  returned as the sample id and the remaining columns as the data row. Only
  `lines_per_chunk` rows are held in memory at a time.

  The file is read forward-only, so indices must be requested sequentially
  starting at 0 (e.g. a DataLoader with shuffle=False and num_workers=0).
  Requesting index 0 restarts reading from the top of the file, so the dataset
  can be iterated for several epochs or after an interrupted pass.

  Args:
    file_path: path to the .tsv.gz file.
    nrows: number of lines in the file, including the header line if present.
    lines_per_chunk: number of rows parsed per chunk.
    got_header: whether the first line of the file is a header to skip.
  """
  def __init__(self, file_path, nrows = 500, lines_per_chunk=10, got_header=True):
    self.file_path = file_path
    self.gzip_reader = ReadGzip(self.file_path)
    self.lines_per_chunk = lines_per_chunk
    self.chunk_lines = str()

    self.got_header = got_header
    # nrows counts the header line too; the dataset length excludes it.
    self.nrows = nrows - self.got_header

    # Rows handed out from the currently loaded chunk.
    self.lines_returned_from_chunk = 0
    # Global index of the first row of the currently loaded chunk.
    self.index_corrector = 0
    self.df = pd.DataFrame()
    self.sampleid = None
    self.data = None
    # Index the forward-only reader expects next (None until index 0 is read).
    self._next_index = None

  def loadNextChunk(self):
    """Read the next chunk of lines and parse it into `sampleid` / `data`."""
    self.chunk_lines = self.gzip_reader.load_next_chunk(lines_per_chunk=self.lines_per_chunk)

    self.df = pd.read_csv(StringIO(self.chunk_lines), sep='\t', header=None)
    self.sampleid = self.df.iloc[:,0].values
    self.data = self.df.iloc[:,1:].values

  def _restart(self):
    """Reset all reading state and (re)open the file at its first data row."""
    # A previous pass may have stopped early (e.g. `break` out of a DataLoader
    # loop) and left the handle open; close it before reopening.
    if self.gzip_reader.gzip_file is not None:
      self.gzip_reader.close_gzip()
    self.gzip_reader.open_gzip(skip_header=self.got_header)
    self.index_corrector = 0
    self.lines_returned_from_chunk = 0
    self.loadNextChunk()

  def __len__(self):
    return self.nrows

  def __getitem__(self, index):
    """Return `(data_row, sample_id)` for `index`.

    Raises:
      IndexError: if `index` is outside `[0, len(self))`.
      RuntimeError: if indices are not requested sequentially from 0.
    """
    if not 0 <= index < self.nrows:
      raise IndexError(f'index {index} out of range for dataset of length {self.nrows}')

    if index == 0:
      self._restart()
    elif index != self._next_index:
      raise RuntimeError(
        f'GzChunckChunker only supports sequential access starting at index 0; '
        f'expected index {self._next_index}, got {index}')

    # Current chunk exhausted: advance to the next one.
    if self.lines_returned_from_chunk >= self.lines_per_chunk:
      self.index_corrector += self.lines_returned_from_chunk
      self.loadNextChunk()
      self.lines_returned_from_chunk = 0

    self.lines_returned_from_chunk += 1
    self._next_index = index + 1

    local_index = index - self.index_corrector
    row = self.data[local_index], self.sampleid[local_index]

    # Close file once the last row has been served (its chunk is already in memory).
    if index == self.nrows - 1:
      self.gzip_reader.close_gzip()
      print('close gzipfile at idx', index)

    return row


In [282]:
# Data configuration. NROWS_DATASET counts every line of the file, header included.
# Other files used during development (swap the file name and NROWS_DATASET):
#   head500_archs4_gene_expression_norm_transposed.tsv.gz  (500 lines)
#   test.tsv.gz                                            (31 lines, CHUNK_SIZE=10, BATCH_SIZE=4)
DATA_DIR = "/content/drive/MyDrive/Colab Notebooks/DeepLearning02456/Project"
NROWS_DATASET = 5000
CHUNK_SIZE = 100
BATCH_SIZE = 50
gz_path = os.path.join(DATA_DIR, "head5000_archs4.tsv.gz")
GzChunks = GzChunckChunker(file_path=gz_path, nrows=NROWS_DATASET, lines_per_chunk=CHUNK_SIZE, got_header=True)

# shuffle must stay False: the dataset reads the gzip file forward-only.
loader = DataLoader(GzChunks, batch_size=BATCH_SIZE, shuffle=False)
for test_data, sampleidd in loader:
  print('test_data', test_data)
  print('sample id', sampleidd)
  break

# Breaking out after one batch leaves the gzip handle open; release it.
if GzChunks.gzip_reader.gzip_file is not None:
  GzChunks.gzip_reader.close_gzip()


test_data tensor([[2.8080, 0.1467, 1.5222,  ..., 7.9495, 1.2313, 2.5236],
        [2.3922, 0.0362, 7.5704,  ..., 6.7737, 1.1999, 2.2073],
        [4.2884, 0.7774, 1.5978,  ..., 6.4182, 4.4427, 4.7735],
        ...,
        [4.1706, 0.0000, 8.2280,  ..., 7.9480, 3.6240, 4.1708],
        [4.1144, 0.2157, 3.3651,  ..., 8.5822, 6.3513, 6.6163],
        [3.6734, 0.0696, 0.0000,  ..., 5.9741, 3.3167, 5.2970]],
       dtype=torch.float64)
sample id ('GSM4747249', 'GSM4063503', 'GSM4631135', 'GSM5374404', 'GSM4891113', 'GSM1900353', 'GSM4664285', 'GSM5009234', 'GSM3141844', 'GSM4664679', 'GSM2687514', 'GSM2664079', 'GSM5099597', 'GSM3188526', 'GSM2109424', 'GSM5221397', 'GSM4024072', 'GSM1636860', 'GSM4182335', 'GSM2109332', 'GSM4162785', 'GSM4929250', 'GSM2717685', 'GSM5370056', 'GSM4587542', 'GSM4451386', 'GSM3305408', 'GSM3324233', 'GSM3184674', 'GSM2523146', 'GSM1907124', 'GSM1370700', 'GSM4543024', 'GSM5196521', 'GSM3388297', 'GSM5236702', 'GSM3689866', 'GSM2667617', 'GSM4585241', 'GSM407