<font size="2"> To install a new library (everything here runs on the python3.10 kernel), run: 
/$ python3.10 -m pip install *pandas*
</font>

In [None]:
# List every file in the data directory and keep the MetaImage header (mhd) files
from os import listdir
import os
from os.path import isfile, join
# NOTE(review): `dir` shadows the built-in of the same name; the name is kept
# here because other cells may already reference it.
dir = "/data/Riabova/train3"
# listdir() returns entries in arbitrary order, so sort the paths to keep the
# file order (and with it the seeded train/validation split) reproducible.
all_files = sorted(join(dir, f) for f in listdir(dir))
# Match the real ".mhd" extension rather than any name that merely ends in "mhd".
all_files_mhd = [path for path in all_files if path.endswith(".mhd")]
len(all_files_mhd)

In [None]:
import numpy as np

''' 
Example usage:
    coords_1 = [x1, x2, x3]  # Replace x1, x2, x3 with the actual values
    method = "hex2us"  # Or "us2hex"
    isShifted = True  # Or False
    coords_2 = transformTo(coords_1, method, isShifted) 

'''

def transformTo(coords_1, method, isShifted):
    """Map a 3-D point between the hexapod and ultrasound (US) coordinate systems.

    Example
    -------
    >>> coords_2 = transformTo([x1, x2, x3], "hex2us", True)

    Parameters
    ----------
    coords_1 : sequence of 3 numbers
        Point ``[x1, x2, x3]`` in the source coordinate system.
    method : str
        ``"hex2us"`` multiplies by the 4x4 calibration matrix ``T``;
        ``"us2hex"`` applies the inverse of ``T``.
    isShifted : bool
        When truthy, use the calibration whose translation column is shifted;
        otherwise use the unshifted calibration.

    Returns
    -------
    numpy.ndarray
        The transformed point, shape ``(3,)``.

    Raises
    ------
    ValueError
        If ``method`` is not ``"hex2us"`` or ``"us2hex"``, or if ``coords_1``
        does not contain exactly three values.
    """
    # Homogeneous 4x4 calibration matrices.
    T_needle_tip = np.array([[-0.1601, 3.1823, 0.1383, 13.8778],
                             [0.0178, 0.0806, -3.3695, 112.5630],
                             [-2.4504, -0.1146, -0.3139, 110.3655],
                             [-0.0000, -0.0000, -0.0000, 1.0000]])

    # Same rotation/scale part; only the translation column differs.
    T_needle_tip_shifted = np.array([[-0.1601, 3.1823, 0.1383, 20.8778],
                                     [0.0178, 0.0806, -3.3695, 109.5630],
                                     [-2.4504, -0.1146, -0.3139, 110.3655],
                                     [-0.0000, -0.0000, -0.0000, 1.0000]])

    # Previously every value other than "hex2us" silently selected the inverse
    # transform, so a typo in `method` flipped the direction without warning.
    if method not in ("hex2us", "us2hex"):
        raise ValueError(
            "method must be 'hex2us' or 'us2hex', got {!r}".format(method)
        )

    point = np.asarray(coords_1, dtype=float).ravel()
    if point.size != 3:
        raise ValueError(
            "coords_1 must contain exactly 3 values, got {}".format(point.size)
        )

    T = T_needle_tip_shifted if isShifted else T_needle_tip

    # Append the homogeneous coordinate so the translation column is applied.
    homogeneous = np.append(point, 1.0)

    if method == "hex2us":
        coords_2 = np.dot(T, homogeneous)
    else:
        # Solving T @ x = b is equivalent to inv(T) @ b without forming inv(T).
        coords_2 = np.linalg.solve(T, homogeneous)

    # Drop the homogeneous coordinate again.
    return coords_2[:-1]


def transformToUS(coords_1):
    """Convert a hexapod-frame point to the US frame using the shifted calibration.

    Shorthand for ``transformTo(coords_1, "hex2us", True)``.
    """
    us_coords = transformTo(coords_1, method="hex2us", isShifted=True)
    return us_coords

In [None]:
import numpy as np

def get_labels(f, info):
    """Compute the tip position in US coordinates for every frame of one recording.

    The file name (without directory and extension) encodes the sweep as
    ``x_y_z_alpha_beta_gamma_velocity_axis_timestamp``, for example
    ``0_22_5_0_0_0_3_1_1686308127391259``. The tip starts at ``(x, y, z)`` in
    the hexapod system and moves along ``axis`` (1-based) at ``velocity`` mm/s
    from the start timestamp (microseconds). Each frame's position is the
    start position advanced by the elapsed time and then mapped to the US
    system with ``transformToUS``.

    Parameters
    ----------
    f : str
        Path of the recording; only its base name is parsed.
    info : mapping
        Header of the recording. ``info['Dimensions'][3]`` is used as the
        number of frames and ``info['Seq_FrameNNNN_Timestamp']`` as the
        timestamp of frame ``NNNN`` (multiplied by 1e6 to get microseconds,
        i.e. presumably stored in seconds).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_frames, 3)`` with one US-frame position per frame.
    """
    n_frames = info['Dimensions'][3]
    labels = np.zeros((n_frames, 3))

    # The sweep parameters depend only on the file name, so parse them once
    # instead of once per frame.
    filename = f.split("/")[-1].split(".")[0]  # 0_22_5_0_0_0_3_1_1686308127391259
    filename_params = filename.split('_')
    start_timestamp_micros = float(filename_params[-1])  # in micros
    axis = int(filename_params[-2])  # 1-based index of the moving axis
    velocity = int(filename_params[-3])  # in mm/s
    start_coords = [float(filename_params[i]) for i in range(3)]

    for frame in range(n_frames):
        # Zero-pad the frame index to four digits. The previous hand-built
        # prefixes produced the same keys for frames 0-99 but a five-digit key
        # ("Seq_Frame00100_...") from frame 100 onwards.
        # NOTE(review): assumes the header keeps four-digit padding for
        # frames >= 100 -- confirm against a long recording.
        timestamp_field = f"Seq_Frame{frame:04d}_Timestamp"
        frame_timestamp_micros = float(info[timestamp_field]) * 1000000

        # Distance travelled along the moving axis since the sweep started.
        dt = frame_timestamp_micros - start_timestamp_micros  # in micros
        distance = velocity * dt / 1000000  # in mm

        # Tip position in the hexapod system: only the moving axis changes.
        hex_coords = list(start_coords)
        if 1 <= axis <= 3:
            hex_coords[axis - 1] += distance

        # Tip position in the US system.
        labels[frame, :] = transformToUS(hex_coords)

    return labels

In [None]:
from utils.type_reader import mha_read_header
import numpy as np
from tqdm import tqdm

# Build one flat, frame-level index over all recordings: for every frame keep
# the path of its source file, its index inside that file, and its label.
#
# Per-file pieces are collected in lists and joined once at the end. Growing
# the arrays with np.concatenate inside the loop re-copies them on every
# iteration and required an uninitialised np.empty() placeholder row that had
# to be sliced off afterwards.
filenames_per_frame = []
# Seeded with empty float arrays so that an empty file list still produces
# correctly shaped results and frame_nums keeps its float64 dtype.
frame_num_chunks = [np.empty(0)]
label_chunks = [np.empty((0, 3))]

for i, f in enumerate(all_files_mhd):
    if i % 20 == 0:
        print("File "+str(i)+"...")
    info = mha_read_header(f)
    n_frames = info['Dimensions'][3]
    label_chunks.append(get_labels(f, info))
    filenames_per_frame.extend([f] * n_frames)
    frame_num_chunks.append(np.arange(0, n_frames))

all_frames_filenames_array = np.array(filenames_per_frame)
labels = np.concatenate(label_chunks, axis=0)
frame_nums = np.concatenate(frame_num_chunks, axis=0)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
from PIL import Image
from utils.type_reader import get_image_array

# Sanity check: load the first recording and show one 2-D slice of it
# (index 69 on the third axis, index 3 on the fourth axis).
input_image = get_image_array(all_files_mhd[0])
preview_slice = input_image[:, :, 69, 3]
plt.imshow(preview_slice)

In [None]:
# One row per frame: column 0 is the source file path, column 1 the frame
# index inside that file.
# NOTE(review): stacking paths with numbers presumably yields a string array,
# so the frame index is stored as text -- confirm the dataset class parses it.
X = np.vstack([all_frames_filenames_array, frame_nums]).T
y = labels

# Quick look at the overall shape and at the first sample/label pair.
print(X.shape)
print(X[0])
print(y[0])

<font size="4">Setting parameters:</font>

In [None]:
# Data split: passed as `test_size` to train_test_split, i.e. a fraction
# (0.2 = 20 % of the frames), not a percentage.
SPLIT_PERCENT = 0.2

# Mini-batch sizes of the training and validation DataLoaders.
BATCH_TRAIN = 5
BATCH_VALID = 5

# Not referenced by the cells shown in this notebook yet.
IMG_SIZE = 135
EPOCHS = 20

In [None]:
from sklearn.model_selection import train_test_split

# Hold out SPLIT_PERCENT of the frames for validation; the fixed seed keeps the
# shuffled split identical between runs (for an identical input order).
# NOTE(review): the split is per frame, so frames of one recording can land in
# both sets -- consider splitting per file if that leaks information.
X_train, X_valid, y_train, y_valid = train_test_split(
    X,
    y,
    test_size=SPLIT_PERCENT,
    shuffle=True,
    random_state=42,
)

In [None]:
from utils.datasets import ImageDataset
import torch
# Wrap each split in a dataset and expose it through a batched loader.
train_dataset = ImageDataset(X_train, y_train)
valid_dataset = ImageDataset(X_valid, y_valid)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_TRAIN,
    shuffle=True,
)
# NOTE(review): the validation loader is shuffled as well; shuffling is usually
# only needed for training -- confirm this is intended.
valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=BATCH_VALID,
    shuffle=True,
)

In [None]:
from utils.display_utils import show_sample

# Visual check of the first training sample.
first_train_sample = train_dataset[0]
show_sample(first_train_sample)

In [None]:
####################
#TODO: develop torch model and fit it
####################