In [None]:
"""
Implement a SnapshotArray that supports the following interface:

SnapshotArray(int length) initializes an array-like data structure with the given length. Initially, each element equals 0.
void set(index, val) sets the element at the given index to be equal to val.
int snap() takes a snapshot of the array and returns the snap_id: the total number of times we called snap() minus 1.
int get(index, snap_id) returns the value at the given index, at the time we took the snapshot with the given snap_id.


Example 1:
    Input: 
        ["SnapshotArray","set","snap","set","get"]
        [[3],[0,5],[],[0,6],[0,0]]
    Output: 
        [null,null,0,null,5]
    Explanation: 
        SnapshotArray snapshotArr = new SnapshotArray(3); // set the length to be 3
        snapshotArr.set(0,5);  // Set array[0] = 5
        snapshotArr.snap();  // Take a snapshot, return snap_id = 0
        snapshotArr.set(0,6);
        snapshotArr.get(0,0);  // Get the value of array[0] with snap_id = 0, return 5

Constraints:
    1 <= length <= 5 * 10^4
    0 <= index < length
    0 <= val <= 10^9
    0 <= snap_id < (the total number of times we call snap())
    At most 5 * 10^4 calls will be made to set, snap, and get.
"""

# On the right track; the only missing piece was
# optimizing get() with a binary search.
import bisect

class SnapshotArray:
    """Array with snapshots, storing a per-index history of writes.

    ``self.vals[i]`` is a list of ``[snap_id, val]`` pairs sorted by
    ``snap_id``: each pair records the value index ``i`` holds from that snap
    period onward. A new pair is appended only when an index is written after
    a snap, so memory grows with the number of writes rather than with
    ``length * number_of_snaps``.
    """

    def __init__(self, length: int):
        # Every index starts out as 0 during snap period 0.
        self.vals = [[[0, 0]] for _ in range(length)]
        self.snap_idx = 0  # id that the next snap() call will return

    def set(self, index: int, val: int) -> None:
        """Set ``index`` to ``val`` in the current (not yet snapped) version."""
        history = self.vals[index]
        if history[-1][0] == self.snap_idx:
            # Already written during this snap period: overwrite in place.
            history[-1][1] = val
        else:
            history.append([self.snap_idx, val])

    def snap(self) -> int:
        """Take a snapshot and return its id (number of snap() calls minus 1)."""
        self.snap_idx += 1
        return self.snap_idx - 1

    def get(self, index: int, snap_id: int) -> int:
        """Return the value ``index`` held when snapshot ``snap_id`` was taken."""
        history = self.vals[index]
        # ``[snap_id + 1]`` is a proper prefix of any ``[snap_id + 1, val]``
        # pair, so it sorts before it. bisect_left therefore lands on the first
        # entry recorded after ``snap_id``, independent of the stored values
        # (the old ``[snap_id, 10**9]`` sentinel broke for values > 10**9).
        pos = bisect.bisect_left(history, [snap_id + 1])
        # history[0] is always [0, ...], so pos >= 1 for any snap_id >= 0.
        return history[pos - 1][1]

# Still taking too much memory, time
class SnapshotArray:
    """Snapshot array storing ``[val, snap_lim]`` pairs per index.

    For every entry except the last, ``snap_lim`` is the last snap id
    (inclusive) at which ``val`` was current. The last entry is still open:
    its ``snap_lim`` is the snap period during which it was written, and it
    stays current for every later snap until the index is written again.
    ``get`` scans the history linearly.
    """

    def __init__(self, length: int):
        # Every index starts out as 0, written during snap period 0.
        self.vals = [[[0, 0]] for _ in range(length)]
        self.snap_idx = 0  # id that the next snap() call will return

    def set(self, index: int, val: int) -> None:
        """Set ``index`` to ``val`` in the current (not yet snapped) version."""
        history = self.vals[index]
        _, idx_limit = history[-1]
        if idx_limit == self.snap_idx:
            # The open entry was written during this snap period: overwrite.
            history[-1][0] = val
        else:
            # Close the open entry: it was current up to the previous snap
            # (the old code used snap_idx here, leaking it into one extra snap)...
            history[-1][1] = self.snap_idx - 1
            # ...and open a new entry tagged with the current snap period.
            history.append([val, self.snap_idx])

    def snap(self) -> int:
        """Take a snapshot and return its id (number of snap() calls minus 1)."""
        snap_id = self.snap_idx
        self.snap_idx += 1
        return snap_id

    def get(self, index: int, snap_id: int) -> int:
        """Return the value ``index`` held when snapshot ``snap_id`` was taken."""
        history = self.vals[index]
        for val, snap_lim in history:
            if snap_lim >= snap_id:
                return val
        # snap_id is later than every recorded limit: the open entry applies.
        return history[-1][0]
            

# Still taking too much memory, time
class SnapshotArray:
    """Snapshot array that records every index's value on each snap().

    ``self.vals[i]`` is ``[current_value, snaps]``, where ``snaps[k]`` is the
    value index ``i`` held when snapshot ``k`` was taken. Each snap() costs
    O(length) time and memory.
    """

    def __init__(self, length: int):
        # Each slot needs a current value AND its own snapshot list; the old
        # bare ``[0]`` could not be unpacked as ``val, snaps`` in snap().
        self.vals = [[0, []] for _ in range(length)]
        self.snap_called = 1  # number of snap() calls so far, plus one

    def set(self, index: int, val: int) -> None:
        """Set the current value at ``index``."""
        self.vals[index][0] = val

    def snap(self) -> int:
        """Append every current value to its snapshot list; return the snap id."""
        snap_id = self.snap_called - 1
        for val, snaps in self.vals:
            snaps.append(val)
        self.snap_called += 1
        return snap_id

    def get(self, index: int, snap_id: int) -> int:
        """Return the value ``index`` held when snapshot ``snap_id`` was taken."""
        return self.vals[index][1][snap_id]

# Your SnapshotArray object will be instantiated and called as such:
# obj = SnapshotArray(length)
# obj.set(index,val)
# param_2 = obj.snap()
# param_3 = obj.get(index,snap_id)

# Memory limit exceeded: every snap() copies the whole array, so space
# complexity is O(length * number of snaps).
class SnapshotArray:
    """Snapshot array that keeps a full copy of the array for every snap.

    Simple but memory-hungry: each snap() costs O(length) time and space.
    """

    def __init__(self, length: int):
        self.vals = [0 for _ in range(length)]  # live (unsnapped) values
        self.snaps = {}  # snap id -> frozen copy of ``vals``
        self.snap_called = 1  # number of snap() calls so far, plus one

    def set(self, index: int, val: int) -> None:
        """Write ``val`` at ``index`` in the live array."""
        self.vals[index] = val

    def snap(self) -> int:
        """Freeze a copy of the live array and return its snap id."""
        new_id = self.snap_called - 1
        self.snap_called += 1
        self.snaps[new_id] = list(self.vals)
        return new_id

    def get(self, index: int, snap_id: int) -> int:
        """Return the value ``index`` held when snapshot ``snap_id`` was taken."""
        frozen = self.snaps[snap_id]
        return frozen[index]