In [1]:
import torch

if torch.cuda.is_available():
    print("CUDA is available")
    print(f"Device count: {torch.cuda.device_count()}")
    print(f"Device name: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not available")

CUDA is available
Device count: 8
Device name: Tesla V100-SXM2-32GB


In [2]:
import torch

print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)

if torch.cuda.is_available():
    print("CUDA is available")
    for i in range(torch.cuda.device_count()):
        print(f"Device {i}: {torch.cuda.get_device_name(i)}")
else:
    print("CUDA is not available")

PyTorch version: 2.4.0
CUDA version: 12.1
CUDA is available
Device 0: Tesla V100-SXM2-32GB
Device 1: Tesla V100-SXM2-32GB
Device 2: Tesla V100-SXM2-32GB
Device 3: Tesla V100-SXM2-32GB
Device 4: Tesla V100-SXM2-32GB
Device 5: Tesla V100-SXM2-32GB
Device 6: Tesla V100-SXM2-32GB
Device 7: Tesla V100-SXM2-32GB


In [3]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [4]:
# Path to the CSV file
csv_file_for_labels = '../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/ptbxl_train_label_df.csv'
# Path to the image directory
data_dir = '../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/records100_ground_truth'

# Load the CSV file
label_df = pd.read_csv(csv_file_for_labels)
train_df = label_df.sample(frac = 0.8)
test_df = label_df.drop(train_df.index)
label_df.head()

Unnamed: 0.2,Unnamed: 0.1,index,Unnamed: 0,ecg_id,patient_id,filename_lr,filename_hr,Normal_ECG,ecg_lr_path
0,0,108,108,109,21312.0,records100/00000/00109_lr,records500/00000/00109_hr,True,00109_lr
1,1,19314,19314,19353,19389.0,records100/19000/19353_lr,records500/19000/19353_hr,False,19353_lr
2,2,12707,12707,12739,16579.0,records100/12000/12739_lr,records500/12000/12739_hr,True,12739_lr
3,3,18414,18414,18453,21182.0,records100/18000/18453_lr,records500/18000/18453_hr,False,18453_lr
4,4,10879,10879,10906,14854.0,records100/10000/10906_lr,records500/10000/10906_hr,True,10906_lr


In [5]:
len(train_df)

12556

In [6]:
len(test_df)

3139

In [36]:
# Define transformations
img_size = 224  # or whatever size you want
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    normalize,
])

In [37]:
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class ECGImageDataset(Dataset):
    def __init__(self, label_df, image_dir, transform=None):
        self.label_df = label_df
        self.image_dir = image_dir
        self.transform = transform
        self.image_paths = self._get_image_paths()

    def _get_image_paths(self):
        image_paths = []
        for root, _, files in os.walk(self.image_dir):
            for file in files:
                if file.endswith('.png'):
                    image_paths.append(os.path.join(root, file))
        return image_paths

    def __len__(self):
        return len(self.label_df)

    def __getitem__(self, idx):
        img_name = self.label_df.iloc[idx]['ecg_lr_path'] + '-0.png'
        matching_paths = [path for path in self.image_paths if img_name in path]
        
        # Use the first match if it exists
        img_path = matching_paths[0] if matching_paths else None
        
        while img_path is None:
            idx += 1
            img_name = self.label_df.iloc[idx]['ecg_lr_path'] + '-0.png'
            matching_paths = [path for path in self.image_paths if img_name in path]
            # Use the first match if it exists
            img_path = matching_paths[0] if matching_paths else None

        image = Image.open(img_path).convert('RGB')
        label = self.label_df.iloc[idx]['Normal_ECG']
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

train_dataset = ECGImageDataset(train_df, data_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=False)

test_dataset = ECGImageDataset(test_df, data_dir, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=False)

# Example of iterating through the dataloader
for images, labels in train_loader:
    print(images.shape, labels.shape)
    break


torch.Size([32, 3, 224, 224]) torch.Size([32])


In [38]:
print(labels.int())

tensor([0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1,
        1, 1, 1, 1, 0, 0, 1, 1], dtype=torch.int32)


### Creating information dfs for 5 class classification

In [2]:
import pandas as pd
import ast
import numpy as np

scp_statements_path = '../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/scp_statements.csv'
database_path = '../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/ptbxl_database.csv'

In [3]:
# scp statements file
df = pd.read_csv(scp_statements_path)

In [4]:
df.head()

Unnamed: 0.1,Unnamed: 0,description,diagnostic,form,rhythm,diagnostic_class,diagnostic_subclass,Statement Category,SCP-ECG Statement Description,AHA code,aECG REFID,CDISC Code,DICOM Code
0,NDT,non-diagnostic T abnormalities,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,non-diagnostic T abnormalities,,,,
1,NST_,non-specific ST changes,1.0,1.0,,STTC,NST_,Basic roots for coding ST-T changes and abnorm...,non-specific ST changes,145.0,MDC_ECG_RHY_STHILOST,,
2,DIG,digitalis-effect,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,suggests digitalis-effect,205.0,,,
3,LNGQT,long QT-interval,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,long QT-interval,148.0,,,
4,NORM,normal ECG,1.0,,,NORM,NORM,Normal/abnormal,normal ECG,1.0,,,F-000B7


In [5]:
df['diagnostic_class'].unique()

array(['STTC', 'NORM', 'MI', 'HYP', 'CD', nan], dtype=object)

In [6]:
# database
df2 = pd.read_csv(database_path)
df2.head()

Unnamed: 0,ecg_id,patient_id,age,sex,height,weight,nurse,site,device,recording_date,...,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr
0,1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,...,True,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr
1,2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,...,True,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr
2,3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,...,True,,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr
3,4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,...,True,", II,III,AVF",,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr
4,5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,...,True,", III,AVR,AVF",,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr


In [7]:
print(len(df2))
print(len(df2.loc[df2['validated_by_human'] == True]))

21799
16056


In [8]:
# load and convert annotation data
Y = pd.read_csv(database_path, index_col='ecg_id')
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

In [9]:
Y.scp_codes

ecg_id
1                 {'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}
2                             {'NORM': 80.0, 'SBRAD': 0.0}
3                               {'NORM': 100.0, 'SR': 0.0}
4                               {'NORM': 100.0, 'SR': 0.0}
5                               {'NORM': 100.0, 'SR': 0.0}
                               ...                        
21833    {'NDT': 100.0, 'PVC': 100.0, 'VCLVH': 0.0, 'ST...
21834             {'NORM': 100.0, 'ABQRS': 0.0, 'SR': 0.0}
21835                           {'ISCAS': 50.0, 'SR': 0.0}
21836                           {'NORM': 100.0, 'SR': 0.0}
21837                           {'NORM': 100.0, 'SR': 0.0}
Name: scp_codes, Length: 21799, dtype: object

In [10]:
# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(scp_statements_path, index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

# Apply diagnostic superclass
Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)

In [11]:
Y.head(20)

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,diagnostic_superclass
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,...,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,[NORM]
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,[NORM]
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,[NORM]
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,", II,III,AVF",,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,[NORM]
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,", III,AVR,AVF",,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,[NORM]
6,19005.0,18.0,1,,58.0,2.0,0.0,CS-12 E,1984-11-28 13:32:13,sinusrhythmus normales ekg,...,", V1",,,,,,4,records100/00000/00006_lr,records500/00000/00006_hr,[NORM]
7,16193.0,54.0,0,,83.0,2.0,0.0,CS-12 E,1984-11-28 13:32:22,"sinusrhythmus linkstyp t abnormal, wahrscheinl...",...,,,,,,,7,records100/00000/00007_lr,records500/00000/00007_hr,[NORM]
8,11275.0,48.0,0,,95.0,2.0,0.0,CS-12 E,1984-12-01 14:49:52,sinusrhythmus linkstyp qrs(t) abnormal infe...,...,", II,AVF",", I-AVF,",,,,,9,records100/00000/00008_lr,records500/00000/00008_hr,[MI]
9,18792.0,55.0,0,,70.0,2.0,0.0,CS-12 E,1984-12-08 09:44:43,sinusrhythmus normales ekg,...,,", I-AVR,",,,,,10,records100/00000/00009_lr,records500/00000/00009_hr,[NORM]
10,9456.0,22.0,1,,56.0,2.0,0.0,CS-12 E,1984-12-12 14:12:46,sinusrhythmus normales ekg,...,,,,,,,9,records100/00000/00010_lr,records500/00000/00010_hr,[NORM]


In [12]:
# Split data into train and test
test_fold = 10

# Train
y_train = Y[(Y.strat_fold != test_fold)].diagnostic_superclass
# Test
y_test = Y[Y.strat_fold == test_fold].diagnostic_superclass

In [13]:
# Filter to get only elements with one class
y_train_single_class = y_train[y_train.apply(lambda x: len(x) == 1)]

# Get unique classes in the filtered elements
unique_classes = np.unique(y_train_single_class)

print(unique_classes)
len(y_train_single_class)

[list(['CD']) list(['HYP']) list(['MI']) list(['NORM']) list(['STTC'])]


14594

In [14]:
y_train_single_class[:30]

ecg_id
1     [NORM]
2     [NORM]
3     [NORM]
4     [NORM]
5     [NORM]
6     [NORM]
7     [NORM]
8       [MI]
10    [NORM]
11    [NORM]
12    [NORM]
13    [NORM]
14    [NORM]
15    [NORM]
16    [NORM]
19    [NORM]
21    [NORM]
22    [STTC]
24    [NORM]
25    [NORM]
26    [STTC]
27    [NORM]
28    [STTC]
29    [NORM]
30     [HYP]
31    [NORM]
32      [CD]
33    [NORM]
35    [NORM]
36    [NORM]
Name: diagnostic_superclass, dtype: object

In [15]:
from sklearn.preprocessing import LabelEncoder
# Flatten the list structure
y_train_flat = y_train_single_class.apply(lambda x: x[0])

# Initialize the label encoder
label_encoder = LabelEncoder()

# Fit the label encoder and transform the labels to integer encoded labels
y_train_encoded = label_encoder.fit_transform(y_train_flat)

print("Integer Encoded Labels: ", y_train_encoded[:30])

Integer Encoded Labels:  [3 3 3 3 3 3 3 2 3 3 3 3 3 3 3 3 3 4 3 3 4 3 4 3 1 3 0 3 3 3]


In [16]:
# Print the mapping of integers to original labels
label_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))
print("Label Mapping: ", label_mapping)

Label Mapping:  {'CD': 0, 'HYP': 1, 'MI': 2, 'NORM': 3, 'STTC': 4}


In [17]:
# examples[0].shape

# Create test set file

In [85]:
import os
import pandas as pd
import ast
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import torchvision.transforms as transforms

test = True
freq = 500

# Load the database file
ptb_xl_database_df = pd.read_csv(database_path, index_col='ecg_id')
ptb_xl_database_df.scp_codes = ptb_xl_database_df.scp_codes.apply(lambda x: ast.literal_eval(x))

# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(scp_statements_path, index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

# Apply diagnostic superclass
ptb_xl_database_df['diagnostic_superclass'] = ptb_xl_database_df.scp_codes.apply(aggregate_diagnostic)
Y = ptb_xl_database_df

# Split data into train and test
test_fold = 10
y_train = Y[Y.strat_fold != test_fold]
y_test = Y[Y.strat_fold == test_fold]

if test:
    y = y_test
else:
    y = y_train
    
if freq == 100:
    y_file_names = y.filename_lr.apply(lambda x: x.split('/')[-1])
else:
    y_file_names = y.filename_hr.apply(lambda x: x.split('/')[-1])

# Filter to get only elements with one class
y_single_class = y[y.diagnostic_superclass.apply(lambda x: len(x) == 1)]
y_file_names = y_file_names[y.diagnostic_superclass.apply(lambda x: len(x) == 1)]

# Flatten the list structure
y_single_class_flat = y_single_class.diagnostic_superclass.apply(lambda x: x[0])

# Initialize the label encoder
label_encoder = LabelEncoder()

# Fit the label encoder and transform the labels to integer encoded labels
y_encoded = label_encoder.fit_transform(y_single_class_flat)

# Print the mapping of integers to original labels
label_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))
print("Label Mapping: ", label_mapping)

y_labels = y_encoded
y_paths = y_file_names.loc[y_single_class.index]
# Reset index 
y_paths.reset_index(drop=True, inplace=True)
y_paths.index += 0  # Update index to start from 0
y_paths.index.name = 'index'

Label Mapping:  {'CD': 0, 'HYP': 1, 'MI': 2, 'NORM': 3, 'STTC': 4}


In [86]:
len(y_paths)

1650

In [87]:
len(y_labels)

1650

In [88]:
import os
if freq == 100:
    img_dir = "../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/records100_ground_truth"
else:
    img_dir = "../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/records500_ground_truth"
image_paths = []
for root, _, files in os.walk(img_dir):
    for file in files:
        if file.endswith('.png'):
            image_paths.append(os.path.join(root, file))
# image_paths

In [89]:
# Create an empty list to store the matched paths and labels
data = []

# Iterate over y_paths and image_paths to find matches and create rows for the dataframe
for i in range(len(y_paths)):
    y_path = y_paths[i]
    for j in range(len(image_paths)):
        img_path = image_paths[j]
        if y_path in img_path:
            data.append([img_path, y_labels[i]])

# Convert the list to a pandas DataFrame
df = pd.DataFrame(data, columns=['Image Path', 'Label'])

# Print the first few rows of the DataFrame to verify
print(df.head())
print(f"Total matches: {len(df)}")

                                          Image Path  Label
0  ../../../../../data/padmalab_external/special_...      3
1  ../../../../../data/padmalab_external/special_...      3
2  ../../../../../data/padmalab_external/special_...      3
3  ../../../../../data/padmalab_external/special_...      3
4  ../../../../../data/padmalab_external/special_...      3
Total matches: 571


In [90]:
df.head()

Unnamed: 0,Image Path,Label
0,../../../../../data/padmalab_external/special_...,3
1,../../../../../data/padmalab_external/special_...,3
2,../../../../../data/padmalab_external/special_...,3
3,../../../../../data/padmalab_external/special_...,3
4,../../../../../data/padmalab_external/special_...,3


In [91]:
df.to_csv(f'test-{freq}HZ-files-and-labels.csv', index=False)

# Create train set

In [92]:
import os
import pandas as pd
import ast
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import torchvision.transforms as transforms

test = False
freq = 500

# Load the database file
ptb_xl_database_df = pd.read_csv(database_path, index_col='ecg_id')
ptb_xl_database_df.scp_codes = ptb_xl_database_df.scp_codes.apply(lambda x: ast.literal_eval(x))

# Load scp_statements.csv for diagnostic aggregation
agg_df = pd.read_csv(scp_statements_path, index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

def aggregate_diagnostic(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

# Apply diagnostic superclass
ptb_xl_database_df['diagnostic_superclass'] = ptb_xl_database_df.scp_codes.apply(aggregate_diagnostic)
Y = ptb_xl_database_df

# Split data into train and test
test_fold = 10
y_train = Y[Y.strat_fold != test_fold]
y_test = Y[Y.strat_fold == test_fold]

if test:
    y = y_test
else:
    y = y_train

if freq == 100:
    y_file_names = y.filename_lr.apply(lambda x: x.split('/')[-1])
else:
    y_file_names = y.filename_hr.apply(lambda x: x.split('/')[-1])

# Filter to get only elements with one class
y_single_class = y[y.diagnostic_superclass.apply(lambda x: len(x) == 1)]
y_file_names = y_file_names[y.diagnostic_superclass.apply(lambda x: len(x) == 1)]

# Flatten the list structure
y_single_class_flat = y_single_class.diagnostic_superclass.apply(lambda x: x[0])

# Initialize the label encoder
label_encoder = LabelEncoder()

# Fit the label encoder and transform the labels to integer encoded labels
y_encoded = label_encoder.fit_transform(y_single_class_flat)

# Print the mapping of integers to original labels
label_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))
print("Label Mapping: ", label_mapping)

y_labels = y_encoded
y_paths = y_file_names.loc[y_single_class.index]
# Reset index 
y_paths.reset_index(drop=True, inplace=True)
y_paths.index += 0  # Update index to start from 0
y_paths.index.name = 'index'

Label Mapping:  {'CD': 0, 'HYP': 1, 'MI': 2, 'NORM': 3, 'STTC': 4}


In [93]:
len(y_paths)

14594

In [94]:
import os
if freq == 100:
    img_dir = "../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/records100_ground_truth"
else:
    img_dir = "../../../../../data/padmalab_external/special_project/physionet.org/files/ptb-xl/1.0.3/records500_ground_truth"
image_paths = []
for root, _, files in os.walk(img_dir):
    for file in files:
        if file.endswith('.png'):
            image_paths.append(os.path.join(root, file))
# image_paths

In [95]:
# Create an empty list to store the matched paths and labels
data = []

# Iterate over y_paths and image_paths to find matches and create rows for the dataframe
for i in range(len(y_paths)):
    y_path = y_paths[i]
    for j in range(len(image_paths)):
        img_path = image_paths[j]
        if y_path in img_path:
            data.append([img_path, y_labels[i]])

# Convert the list to a pandas DataFrame
df = pd.DataFrame(data, columns=['Image Path', 'Label'])

# Print the first few rows of the DataFrame to verify
print(df.head())
print(f"Total matches: {len(df)}")

                                          Image Path  Label
0  ../../../../../data/padmalab_external/special_...      3
1  ../../../../../data/padmalab_external/special_...      3
2  ../../../../../data/padmalab_external/special_...      3
3  ../../../../../data/padmalab_external/special_...      3
4  ../../../../../data/padmalab_external/special_...      3
Total matches: 4030


In [96]:
for i in range(4030):
    if 'checkpoint' in df.iloc[i]["Image Path"]:
        print("contains")
        print(i)

contains
1


In [97]:
updated_df = df.drop([1])

In [98]:
len(updated_df)

4029

In [99]:
updated_df.to_csv(f'train-{freq}HZ-files-and-labels.csv', index=False)

In [101]:
load_test = pd.read_csv('train-500HZ-files-and-labels.csv')
load_test.head()

Unnamed: 0,Image Path,Label
0,../../../../../data/padmalab_external/special_...,3
1,../../../../../data/padmalab_external/special_...,3
2,../../../../../data/padmalab_external/special_...,3
3,../../../../../data/padmalab_external/special_...,3
4,../../../../../data/padmalab_external/special_...,3


In [71]:
import os
import pandas as pd
import ast
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import torchvision.transforms as transforms

# Define transformations
img_size = 224  # or whatever size you want
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

class ECGImageDataset(Dataset):
    def __init__(self, info_df_path, transform=None):
        self.info_df = pd.read_csv(info_df_path)
        self.transform = transform

    def __len__(self):
        return len(self.info_df)

    def __getitem__(self, idx):
        img_path = self.info_df.iloc[idx]['Image Path']
        image = Image.open(img_path).convert('RGB')
        label = self.info_df.iloc[idx]['Label']

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label).long()

# Initialize dataset and dataloader for training
train_dataset = ECGImageDataset('train-100HZ-files-and-labels.csv', transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=200)
print(len(train_dataset))

# Print the shapes of the examples and labels from the train dataloader
from tqdm import tqdm
for examples, labels in tqdm(train_dataloader):
    print(examples.shape, labels.shape)